"transformer_engine/jax/csrc/extensions/utils.h" did not exist on "73c9f421c704d18b5f7973d72d790ddecc41ba8b"
Unverified Commit 962d9c53 authored by Phuong Nguyen's avatar Phuong Nguyen Committed by GitHub
Browse files

[JAX] Scaling Enum Abstracting (#1655)



* scaling enum abstract

* rm NVTE_ from ScalingMode names

* rework scaling mode enum in grouped gemm

* fix norm sharding

---------
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
parent 9d4e11ea
...@@ -13,7 +13,9 @@ namespace transformer_engine { ...@@ -13,7 +13,9 @@ namespace transformer_engine {
namespace jax { namespace jax {
pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype) { DType in_dtype, DType out_dtype,
JAXX_Scaling_Mode scaling_mode,
QuantizeLayout q_layout) {
auto input_shape = std::vector<size_t>{batch_size, hidden_size}; auto input_shape = std::vector<size_t>{batch_size, hidden_size};
auto output_shape = std::vector<size_t>{batch_size, hidden_size}; auto output_shape = std::vector<size_t>{batch_size, hidden_size};
auto output_trans_shape = std::vector<size_t>{hidden_size, batch_size}; auto output_trans_shape = std::vector<size_t>{hidden_size, batch_size};
...@@ -27,10 +29,37 @@ pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_ ...@@ -27,10 +29,37 @@ pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_
int temp = 0; int temp = 0;
auto input_tensor = TensorWrapper(reinterpret_cast<void *>(&temp), input_shape, in_dtype); auto input_tensor = TensorWrapper(reinterpret_cast<void *>(&temp), input_shape, in_dtype);
auto output_tensor = TensorWrapper(reinterpret_cast<void *>(&temp), output_shape, out_dtype);
output_tensor.set_columnwise_data(reinterpret_cast<void *>(&temp), out_dtype, output_trans_shape);
auto dbias_tensor = TensorWrapper(reinterpret_cast<void *>(&temp), dbias_shape, in_dtype); auto dbias_tensor = TensorWrapper(reinterpret_cast<void *>(&temp), dbias_shape, in_dtype);
auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
// Only the pointers will be checked for scale_inv, thus the shapes do not matter
if (q_layout == QuantizeLayout::ROWWISE_COLWISE || q_layout == QuantizeLayout::ROWWISE) {
output_tensor.set_rowwise_data(reinterpret_cast<void *>(&temp), out_dtype, output_shape);
if (is_fp8_dtype(out_dtype)) {
output_tensor.set_rowwise_scale_inv(reinterpret_cast<void *>(&temp), DType::kFloat32,
std::vector<size_t>{1});
}
}
if (q_layout == QuantizeLayout::ROWWISE_COLWISE || q_layout == QuantizeLayout::COLWISE) {
auto &tmp_shape = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ? output_trans_shape
: output_shape;
output_tensor.set_columnwise_data(reinterpret_cast<void *>(&temp), out_dtype, tmp_shape);
// Only the pointers will be checked for scale_inv, thus the shapes do not matter
if (is_fp8_dtype(out_dtype)) {
output_tensor.set_columnwise_scale_inv(reinterpret_cast<void *>(&temp), DType::kFloat32,
std::vector<size_t>{1});
}
}
if (is_fp8_dtype(out_dtype) && scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
output_tensor.set_amax(reinterpret_cast<void *>(&temp), DType::kFloat32,
std::vector<size_t>{1});
output_tensor.set_scale(reinterpret_cast<void *>(&temp), DType::kFloat32,
std::vector<size_t>{1});
}
TensorWrapper dummy_workspace; TensorWrapper dummy_workspace;
nvte_quantize_dbias(input_tensor.data(), output_tensor.data(), dbias_tensor.data(), nvte_quantize_dbias(input_tensor.data(), output_tensor.data(), dbias_tensor.data(),
...@@ -44,8 +73,8 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T ...@@ -44,8 +73,8 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
Result_Type output_buf, Result_Type output_trans_buf, Result_Type output_buf, Result_Type output_trans_buf,
Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
Result_Type amax_buf, Result_Type dbias_buf, Result_Type workspace_buf, Result_Type amax_buf, Result_Type dbias_buf, Result_Type workspace_buf,
int64_t scaling_mode_enum, int64_t quantize_layout_enum, bool is_dbias, JAXX_Scaling_Mode scaling_mode, int64_t quantize_layout_enum,
int64_t flatten_axis) { bool is_dbias, int64_t flatten_axis) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type()); auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type());
...@@ -54,7 +83,6 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T ...@@ -54,7 +83,6 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
auto *input = input_buf.untyped_data(); auto *input = input_buf.untyped_data();
auto scaling_mode = static_cast<NVTEScalingMode>(scaling_mode_enum);
auto const quantize_layout = static_cast<QuantizeLayout>(quantize_layout_enum); auto const quantize_layout = static_cast<QuantizeLayout>(quantize_layout_enum);
auto *output = output_buf->untyped_data(); auto *output = output_buf->untyped_data();
...@@ -77,14 +105,14 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T ...@@ -77,14 +105,14 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
std::vector<size_t> workspace_shape{workspace_dims.begin(), workspace_dims.end()}; std::vector<size_t> workspace_shape{workspace_dims.begin(), workspace_dims.end()};
auto input_tensor = TensorWrapper(input, input_shape, in_dtype); auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
auto output_tensor = TensorWrapper(scaling_mode); auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
if (quantize_layout == QuantizeLayout::ROWWISE || if (quantize_layout == QuantizeLayout::ROWWISE ||
quantize_layout == QuantizeLayout::ROWWISE_COLWISE) { quantize_layout == QuantizeLayout::ROWWISE_COLWISE) {
output_tensor.set_rowwise_data(output, out_dtype, output_shape); output_tensor.set_rowwise_data(output, out_dtype, output_shape);
if (is_fp8_dtype(out_dtype)) { if (is_fp8_dtype(out_dtype)) {
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
float *scale = reinterpret_cast<float *>(scale_buf.untyped_data()); float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
float *amax = reinterpret_cast<float *>(amax_buf->untyped_data()); float *amax = reinterpret_cast<float *>(amax_buf->untyped_data());
NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling"); NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
...@@ -109,14 +137,16 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T ...@@ -109,14 +137,16 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
if (quantize_layout == QuantizeLayout::COLWISE || if (quantize_layout == QuantizeLayout::COLWISE ||
quantize_layout == QuantizeLayout::ROWWISE_COLWISE) { quantize_layout == QuantizeLayout::ROWWISE_COLWISE) {
auto &tmp_shape = auto &tmp_shape = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING)
(scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? output_trans_shape : output_shape; ? output_trans_shape
: output_shape;
output_tensor.set_columnwise_data(output_trans, out_dtype, tmp_shape); output_tensor.set_columnwise_data(output_trans, out_dtype, tmp_shape);
// For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling // For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling
auto &tmp_buf = auto &tmp_buf = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING)
(scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? scale_inv_buf : colwise_scale_inv_buf; ? scale_inv_buf
: colwise_scale_inv_buf;
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
output_tensor.set_columnwise_scale_inv( output_tensor.set_columnwise_scale_inv(
tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()), tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()),
std::vector<size_t>{1}); std::vector<size_t>{1});
...@@ -153,7 +183,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DBiasQuantizeHandler, DBiasQuantizeFFI, ...@@ -153,7 +183,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DBiasQuantizeHandler, DBiasQuantizeFFI,
.Ret<Buffer_Type>() // amax .Ret<Buffer_Type>() // amax
.Ret<Buffer_Type>() // dbias .Ret<Buffer_Type>() // dbias
.Ret<Buffer_Type>() // wkspace .Ret<Buffer_Type>() // wkspace
.Attr<int64_t>("scaling_mode") .Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<int64_t>("q_layout") .Attr<int64_t>("q_layout")
.Attr<bool>("is_dbias") .Attr<bool>("is_dbias")
.Attr<int64_t>("flatten_axis"), .Attr<int64_t>("flatten_axis"),
......
...@@ -361,7 +361,7 @@ class TransformerEngineBase(nn.Module): # pylint: disable=too-few-public-method ...@@ -361,7 +361,7 @@ class TransformerEngineBase(nn.Module): # pylint: disable=too-few-public-method
).value ).value
return QuantizeMeta(scale=scale, amax_history=amax_history) return QuantizeMeta(scale=scale, amax_history=amax_history)
if QuantizeConfig.SCALING_MODE == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: if QuantizeConfig.SCALING_MODE == ScalingMode.DELAYED_TENSOR_SCALING:
x_meta = generate_quantize_meta("x") x_meta = generate_quantize_meta("x")
kernel_meta = generate_quantize_meta("kernel") kernel_meta = generate_quantize_meta("kernel")
grad_meta = generate_quantize_meta("grad") grad_meta = generate_quantize_meta("grad")
......
...@@ -84,8 +84,8 @@ class Dequantizer: ...@@ -84,8 +84,8 @@ class Dequantizer:
) )
funcs = { funcs = {
ScalingMode.NVTE_DELAYED_TENSOR_SCALING: _dq_func_tensor_scaling, ScalingMode.DELAYED_TENSOR_SCALING: _dq_func_tensor_scaling,
ScalingMode.NVTE_MXFP8_1D_SCALING: _dq_func_block_scaling, ScalingMode.MXFP8_1D_SCALING: _dq_func_block_scaling,
} }
@staticmethod @staticmethod
......
...@@ -94,15 +94,15 @@ def _check_fp8_support(scaling_mode, gpu_id) -> Tuple[bool, str]: ...@@ -94,15 +94,15 @@ def _check_fp8_support(scaling_mode, gpu_id) -> Tuple[bool, str]:
A tuple of (bool, str) indicating support and any error message A tuple of (bool, str) indicating support and any error message
""" """
gpu_arch = get_device_compute_capability(gpu_id) gpu_arch = get_device_compute_capability(gpu_id)
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
return _check_delayed_scaling_fp8_support(gpu_arch) return _check_delayed_scaling_fp8_support(gpu_arch)
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: if scaling_mode == ScalingMode.MXFP8_1D_SCALING:
return _check_block_scaling_fp8_support(gpu_arch) return _check_block_scaling_fp8_support(gpu_arch)
return (False, "Unsupported scaling_mode!") return (False, "Unsupported scaling_mode!")
def is_fp8_available( def is_fp8_available(
scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING,
gpu_id=None, gpu_id=None,
) -> Tuple[bool, str]: ) -> Tuple[bool, str]:
"""Check if FP8 is available for the given scaling mode and GPU. """Check if FP8 is available for the given scaling mode and GPU.
...@@ -179,9 +179,9 @@ def _get_scaling_mode(fp8_recipe: recipe.Recipe) -> ScalingMode: ...@@ -179,9 +179,9 @@ def _get_scaling_mode(fp8_recipe: recipe.Recipe) -> ScalingMode:
ValueError: If the recipe type is not supported ValueError: If the recipe type is not supported
""" """
if isinstance(fp8_recipe, recipe.DelayedScaling): if isinstance(fp8_recipe, recipe.DelayedScaling):
return ScalingMode.NVTE_DELAYED_TENSOR_SCALING return ScalingMode.DELAYED_TENSOR_SCALING
if isinstance(fp8_recipe, recipe.MXFP8BlockScaling): if isinstance(fp8_recipe, recipe.MXFP8BlockScaling):
return ScalingMode.NVTE_MXFP8_1D_SCALING return ScalingMode.MXFP8_1D_SCALING
raise ValueError("Invalid fp8_recipe!") raise ValueError("Invalid fp8_recipe!")
...@@ -217,7 +217,7 @@ class QuantizeConfig: ...@@ -217,7 +217,7 @@ class QuantizeConfig:
FP8_2X_ACC_DGRAD: bool = False FP8_2X_ACC_DGRAD: bool = False
FP8_2X_ACC_WGRAD: bool = False FP8_2X_ACC_WGRAD: bool = False
IF_QUANTIZE_2X: bool = False IF_QUANTIZE_2X: bool = False
SCALING_MODE: ScalingMode = ScalingMode.NVTE_NO_SCALING SCALING_MODE: ScalingMode = ScalingMode.NO_SCALING
# DelayedScaling # DelayedScaling
AMAX_HISTORY_LEN: int = 1024 AMAX_HISTORY_LEN: int = 1024
...@@ -253,11 +253,11 @@ class QuantizeConfig: ...@@ -253,11 +253,11 @@ class QuantizeConfig:
cls.MARGIN = 0.0 cls.MARGIN = 0.0
cls.FP8_FORMAT = recipe.Format.HYBRID cls.FP8_FORMAT = recipe.Format.HYBRID
cls.FWD_DTYPE, cls.BWD_DTYPE = _format2dtypes(cls.FP8_FORMAT) cls.FWD_DTYPE, cls.BWD_DTYPE = _format2dtypes(cls.FP8_FORMAT)
cls.SCALING_MODE = ScalingMode.NVTE_NO_SCALING cls.SCALING_MODE = ScalingMode.NO_SCALING
cls.FP8_2X_ACC_FPROP = False cls.FP8_2X_ACC_FPROP = False
cls.FP8_2X_ACC_DGRAD = False cls.FP8_2X_ACC_DGRAD = False
cls.FP8_2X_ACC_WGRAD = False cls.FP8_2X_ACC_WGRAD = False
cls.SCALING_MODE = ScalingMode.NVTE_NO_SCALING cls.SCALING_MODE = ScalingMode.NO_SCALING
cls.IF_QUANTIZE_2X = False cls.IF_QUANTIZE_2X = False
# DelayedScaling # DelayedScaling
cls.AMAX_HISTORY_LEN = 1024 cls.AMAX_HISTORY_LEN = 1024
......
...@@ -172,7 +172,7 @@ class DelayedScaleQuantizer(Quantizer): ...@@ -172,7 +172,7 @@ class DelayedScaleQuantizer(Quantizer):
amax_history: History of maximum absolute values amax_history: History of maximum absolute values
""" """
scaling_mode: ScalingMode = ScalingMode.NVTE_DELAYED_TENSOR_SCALING scaling_mode: ScalingMode = ScalingMode.DELAYED_TENSOR_SCALING
q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE
scale: jnp.ndarray = field(default_factory=lambda: jnp.ones((1,), jnp.float32)) scale: jnp.ndarray = field(default_factory=lambda: jnp.ones((1,), jnp.float32))
...@@ -375,7 +375,7 @@ class BlockScaleQuantizer(Quantizer): ...@@ -375,7 +375,7 @@ class BlockScaleQuantizer(Quantizer):
q_layout: Quantization axis (default: ROWWISE_COLWISE) q_layout: Quantization axis (default: ROWWISE_COLWISE)
""" """
scaling_mode: ScalingMode = ScalingMode.NVTE_MXFP8_1D_SCALING scaling_mode: ScalingMode = ScalingMode.MXFP8_1D_SCALING
q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE
def get_data_layout(self) -> str: def get_data_layout(self) -> str:
...@@ -530,8 +530,8 @@ class QuantizerFactory: ...@@ -530,8 +530,8 @@ class QuantizerFactory:
""" """
quantizer_type_map = { quantizer_type_map = {
ScalingMode.NVTE_DELAYED_TENSOR_SCALING: DelayedScaleQuantizer, ScalingMode.DELAYED_TENSOR_SCALING: DelayedScaleQuantizer,
ScalingMode.NVTE_MXFP8_1D_SCALING: BlockScaleQuantizer, ScalingMode.MXFP8_1D_SCALING: BlockScaleQuantizer,
} }
@staticmethod @staticmethod
...@@ -556,8 +556,9 @@ class QuantizerFactory: ...@@ -556,8 +556,9 @@ class QuantizerFactory:
A single quantizer or tuple of quantizers A single quantizer or tuple of quantizers
""" """
# (Phuong): add this assert back when NVTE_NO_SCALING is fully implememted # (Phuong): add this assert back when NVTE_NO_SCALING is fully implememted
# assert scaling_mode != ScalingMode.NVTE_INVALID_SCALING assert isinstance(scaling_mode, ScalingMode), "Invalid scaling_mode type"
if scaling_mode in (ScalingMode.NVTE_NO_SCALING, ScalingMode.NVTE_INVALID_SCALING): # import pdb; pdb.set_trace()
if scaling_mode == ScalingMode.NO_SCALING:
quantizers = [None] * n_quantizers quantizers = [None] * n_quantizers
else: else:
quantizers = [] quantizers = []
...@@ -651,4 +652,4 @@ class QuantizerFactory: ...@@ -651,4 +652,4 @@ class QuantizerFactory:
return q_set[0] if len(q_set) == 1 else tuple(q_set) return q_set[0] if len(q_set) == 1 else tuple(q_set)
noop_quantizer_set = QuantizerFactory.create_set(scaling_mode=ScalingMode.NVTE_NO_SCALING) noop_quantizer_set = QuantizerFactory.create_set(scaling_mode=ScalingMode.NO_SCALING)
...@@ -19,6 +19,8 @@ import operator ...@@ -19,6 +19,8 @@ import operator
from jax.tree_util import register_pytree_node_class from jax.tree_util import register_pytree_node_class
import jax.numpy as jnp import jax.numpy as jnp
from transformer_engine_jax import JAXX_Scaling_Mode
__all__ = ["ScalingMode"] __all__ = ["ScalingMode"]
...@@ -216,25 +218,20 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl): ...@@ -216,25 +218,20 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
return (*first_dim_scale_shape, *last_dim_scale_shape) return (*first_dim_scale_shape, *last_dim_scale_shape)
# (Phuong: Map the NVTEScalingMode value to the ScalingMode
@dataclass(frozen=True) @dataclass(frozen=True)
@register_pytree_node_class @register_pytree_node_class
class ScalingMode(Enum): class ScalingMode(Enum):
"""Enumeration of tensor scaling modes with their corresponding metadata implementations. """Enumeration of tensor scaling modes with their corresponding metadata implementations.
This class defines the available scaling modes for tensor quantization: This class defines the available scaling modes for tensor quantization:
- NVTE_DELAYED_TENSOR_SCALING: Uses delayed scaling with FP8 data type and float32 scales - DELAYED_TENSOR_SCALING: Uses delayed scaling with FP8 data type and float32 scales
- NVTE_MXFP8_1D_SCALING: Uses block-based scaling with FP8 data type and E8M0 scales - MXFP8_1D_SCALING: Uses block-based scaling with FP8 data type and E8M0 scales
- NVTE_INVALID_SCALING: Invalid scaling mode - NO_SCALING: No scaling applied
- NVTE_NO_SCALING: No scaling applied
""" """
NVTE_DELAYED_TENSOR_SCALING = 0 NO_SCALING = JAXX_Scaling_Mode.NO_SCALING
NVTE_MXFP8_1D_SCALING = 1 DELAYED_TENSOR_SCALING = JAXX_Scaling_Mode.DELAYED_TENSOR_SCALING
NVTE_INVALID_SCALING = 100 MXFP8_1D_SCALING = JAXX_Scaling_Mode.MXFP8_1D_SCALING
NVTE_NO_SCALING = 1000
def _get_impl(self) -> ScalingModeMetadataImpl: def _get_impl(self) -> ScalingModeMetadataImpl:
"""Get the implementation for this scaling mode. """Get the implementation for this scaling mode.
...@@ -329,8 +326,8 @@ class ScalingMode(Enum): ...@@ -329,8 +326,8 @@ class ScalingMode(Enum):
SCALING_MODES_TO_IMPL: Dict[ScalingMode, ScalingModeMetadataImpl] = { SCALING_MODES_TO_IMPL: Dict[ScalingMode, ScalingModeMetadataImpl] = {
ScalingMode.NVTE_DELAYED_TENSOR_SCALING: DelayedScalingModeMetadataImpl(), ScalingMode.DELAYED_TENSOR_SCALING: DelayedScalingModeMetadataImpl(),
ScalingMode.NVTE_MXFP8_1D_SCALING: BlockScalingModeMetadataImpl(block_dims=(1, 32)), ScalingMode.MXFP8_1D_SCALING: BlockScalingModeMetadataImpl(block_dims=(1, 32)),
# WAR # WAR
ScalingMode.NVTE_NO_SCALING: DelayedScalingModeMetadataImpl(), ScalingMode.NO_SCALING: DelayedScalingModeMetadataImpl(),
} }
...@@ -236,13 +236,12 @@ class ScaledTensor1x(ScaledTensor): ...@@ -236,13 +236,12 @@ class ScaledTensor1x(ScaledTensor):
data = with_sharding_constraint_by_logical_axes(self.data, axis_names) data = with_sharding_constraint_by_logical_axes(self.data, axis_names)
if self.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: if self.scaling_mode == ScalingMode.MXFP8_1D_SCALING:
# TODO(Phuong): Handle padding !? # TODO(Phuong): Handle padding !?
scale_inv = with_sharding_constraint_by_logical_axes(self.scale_inv, axis_names) scale_inv = with_sharding_constraint_by_logical_axes(self.scale_inv, axis_names)
else: else:
scale_inv = self.scale_inv scale_inv = self.scale_inv
# TODO(Phuong): constaint padded scale_inv?
return ScaledTensor1x( return ScaledTensor1x(
data=data, data=data,
scale_inv=scale_inv, scale_inv=scale_inv,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment