Commit a207db1d authored by yuguo
parents fbee8990 69365f88
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Config module for quantization metadata management
This module provides configuration and helper functions for managing quantization metadata
in JAX, including support for different scaling modes and datatypes.
"""
from contextlib import contextmanager
from enum import Enum
from typing import Optional, Tuple, Dict, Union
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from transformer_engine_jax import DType
from transformer_engine_jax import get_cublasLt_version
from transformer_engine_jax import (
get_cuda_version,
get_device_compute_capability,
)
from transformer_engine.common import recipe
from transformer_engine.jax.sharding import global_shard_guard, MeshResource
from .scaling_modes import ScalingMode
from .. import cpp_extensions as tex
__all__ = ["QuantizeConfig", "fp8_autocast", "is_fp8_available", "update_collections"]
_is_fp8_available = None
_reason_for_no_fp8 = ""
Collection = Union[Dict, FrozenDict]
def _check_delayed_scaling_fp8_support(gpu_arch) -> Tuple[bool, str]:
"""Check if delayed scaling FP8 is supported on the given GPU architecture.
Args:
gpu_arch: The GPU architecture version
Returns:
A tuple of (bool, str) indicating support and any error message
"""
if gpu_arch >= 90: # hopper and above
return True, ""
if gpu_arch < 89: # pre-ada
return False, "Device compute capability 8.9 or higher required for FP8 execution."
if get_cublasLt_version() < 120103:
return False, "CublasLt version 12.1.3.x or higher required for FP8 execution on Ada."
if get_cuda_version() < 12010:
return False, "Cuda version 12.1 or higher required for FP8 execution on Ada."
return True, ""
def _check_block_scaling_fp8_support(gpu_arch) -> Tuple[bool, str]:
"""Check if block scaling FP8 is supported on the given GPU architecture.
Args:
gpu_arch: The GPU architecture version
Returns:
A tuple of (bool, str) indicating support and any error message
"""
if gpu_arch >= 100: # blackwell and above
return True, ""
if gpu_arch < 99: # pre-blackwell
return False, "Device compute capability 10.0 or higher required for MXFP8 execution."
if get_cublasLt_version() < 120800:
return False, "CublasLt version 12.8.0 or higher required for MXFP8 execution."
if get_cuda_version() < 12080:
return False, "Cuda version 12.8 or higher required for MXFP8 execution."
if not tex.jax_version_meet_requirement("0.5.3"):
return False, "Jax version 0.5.3 or higher required for MXFP8 execution."
return True, ""
def _check_fp8_support(scaling_mode, gpu_id) -> Tuple[bool, str]:
"""Check if FP8 is supported for the given scaling mode and GPU.
Args:
scaling_mode: The scaling mode to check support for
gpu_id: The ID of the GPU to check
Returns:
A tuple of (bool, str) indicating support and any error message
"""
gpu_arch = get_device_compute_capability(gpu_id)
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
return _check_delayed_scaling_fp8_support(gpu_arch)
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING:
return _check_block_scaling_fp8_support(gpu_arch)
return (False, "Unsupported scaling_mode!")
def is_fp8_available(
scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING,
gpu_id=None,
) -> Tuple[bool, str]:
"""Check if FP8 is available for the given scaling mode and GPU.
Args:
scaling_mode: The scaling mode to check availability for (default: DELAYED_TENSOR_SCALING)
gpu_id: Optional GPU ID to check specific device (default: None)
Returns:
A tuple of (bool, str) indicating availability and any error message
"""
if gpu_id is not None:
return _check_fp8_support(scaling_mode, gpu_id)
global _is_fp8_available, _reason_for_no_fp8
if _is_fp8_available is None:
_is_fp8_available = {}
_reason_for_no_fp8 = {}
if scaling_mode not in _is_fp8_available:
_is_fp8_available[scaling_mode] = True
_reason_for_no_fp8[scaling_mode] = ""
# JAX doesn't provide the local GPU id.
for local_gpu_id in range(len(jax.local_devices())):
ret, msg = _check_fp8_support(scaling_mode, local_gpu_id)
if ret is False:
_is_fp8_available[scaling_mode] = ret
_reason_for_no_fp8[scaling_mode] = msg
return ret, msg
return _is_fp8_available[scaling_mode], _reason_for_no_fp8[scaling_mode]
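# Example (illustrative sketch, not part of the original module): querying FP8
# availability before enabling quantization.
#
#     available, reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING)
#     if not available:
#         print(f"MXFP8 unavailable, falling back to higher precision: {reason}")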
def _format2dtypes(format_: recipe.Format):
"""Convert recipe.Format.dtype to corresponding JAX dtypes.
Args:
format_: The FP8 format to convert
Returns:
A tuple of (forward_dtype, backward_dtype) for the given format
"""
if format_ == recipe.Format.E4M3:
return jnp.float8_e4m3fn, jnp.float8_e4m3fn
if format_ == recipe.Format.E5M2:
return jnp.float8_e5m2, jnp.float8_e5m2
if format_ == recipe.Format.HYBRID:
return jnp.float8_e4m3fn, jnp.float8_e5m2
return jnp.bfloat16, jnp.bfloat16
class AmaxComputeAlgo(Enum):
"""Enumeration for AMAX computation algorithms.
Attributes:
MAX: Use maximum value for AMAX computation
MOST_RECENT: Use most recent value for AMAX computation
"""
MAX = "max"
MOST_RECENT = "most_recent"
def _get_scaling_mode(fp8_recipe: recipe.Recipe) -> ScalingMode:
"""Convert recipe.Recipe to ScalingMode.
Args:
fp8_recipe: The FP8 recipe to convert
Returns:
The corresponding ScalingMode
Raises:
ValueError: If the recipe type is not supported
"""
if isinstance(fp8_recipe, recipe.DelayedScaling):
return ScalingMode.NVTE_DELAYED_TENSOR_SCALING
if isinstance(fp8_recipe, recipe.MXFP8BlockScaling):
return ScalingMode.NVTE_MXFP8_1D_SCALING
raise ValueError("Invalid fp8_recipe!")
def update_collections(new: Collection, original: Collection) -> Collection:
"""Update collections with new values while preserving original structure.
Args:
new: New collection of values to add/update
original: Original collection to update
Returns:
Updated collection with new values merged with original
Raises:
AssertionError: If either collection is not a dict or FrozenDict
"""
assert isinstance(original, (dict, FrozenDict))
assert isinstance(new, (dict, FrozenDict))
frozen_original = FrozenDict(original) if not isinstance(original, FrozenDict) else original
for key in new:
if key in frozen_original:
frozen_original, _ = frozen_original.pop(key)
new_coll = FrozenDict({**new, **frozen_original})
if not isinstance(original, FrozenDict):
new_coll = new_coll.unfreeze()
return new_coll
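# Example (illustrative sketch): merging freshly computed quantize metadata into an
# existing Flax variable collection. `model_vars` and the nested keys are hypothetical.
#
#     new_meta = {"quantize_meta": {"dense_0": {"scale": jnp.ones((1,))}}}
#     model_vars = update_collections(new_meta, model_vars)
#     # Keys present in `new_meta` replace the originals; everything else is preserved,
#     # and the return type (dict vs. FrozenDict) follows the original collection.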
class QuantizeConfig:
"""Configuration class for quantization settings.
This class manages global quantization settings including FP8 formats,
scaling modes, and accumulation settings.
Attributes:
INITIALIZED: Whether the config has been initialized
MARGIN: Margin value for quantization
COLLECTION_NAME: Name of the collection for quantization metadata
FP8_FORMAT: FP8 format to use
FWD_DTYPE: Forward pass data type
BWD_DTYPE: Backward pass data type
FP8_2X_ACC_FPROP: Whether to use 2x accumulation for forward pass
FP8_2X_ACC_DGRAD: Whether to use 2x accumulation for data gradients
FP8_2X_ACC_WGRAD: Whether to use 2x accumulation for weight gradients
IF_QUANTIZE_2X: Whether 2x quantization is enabled
SCALING_MODE: Scaling mode
AMAX_HISTORY_LEN: Length of AMAX history for delayed scaling
AMAX_COMPUTE_ALGO: Algorithm for AMAX computation
"""
INITIALIZED = False
MARGIN: float = 0.0
COLLECTION_NAME: str = "quantize_meta"
FP8_FORMAT: recipe.Format = recipe.Format.HYBRID
FWD_DTYPE: DType = _format2dtypes(recipe.Format.HYBRID)[0]
BWD_DTYPE: DType = _format2dtypes(recipe.Format.HYBRID)[1]
FP8_2X_ACC_FPROP: bool = False
FP8_2X_ACC_DGRAD: bool = False
FP8_2X_ACC_WGRAD: bool = False
IF_QUANTIZE_2X: bool = False
SCALING_MODE: ScalingMode = ScalingMode.NVTE_NO_SCALING
# DelayedScaling
AMAX_HISTORY_LEN: int = 1024
AMAX_COMPUTE_ALGO: AmaxComputeAlgo = AmaxComputeAlgo.MAX
@staticmethod
def is_fp8_enabled():
"""Check if FP8 quantization is enabled.
Returns:
bool: True if quantization is enabled, False otherwise
"""
return QuantizeConfig.INITIALIZED
@classmethod
def initialize(cls, fp8_recipe: recipe.Recipe) -> None:
"""Initialize the quantization configuration.
Args:
fp8_recipe: The FP8 recipe to use for initialization
"""
cls.INITIALIZED = True
cls.MARGIN = fp8_recipe.margin
cls.FP8_FORMAT = fp8_recipe.fp8_format
cls.FWD_DTYPE, cls.BWD_DTYPE = _format2dtypes(cls.FP8_FORMAT)
cls.SCALING_MODE = _get_scaling_mode(fp8_recipe)
cls.IF_QUANTIZE_2X = True
@classmethod
def finalize(cls) -> None:
"""Reset the quantization configuration to default values."""
cls.INITIALIZED = False
cls.MARGIN = 0.0
cls.FP8_FORMAT = recipe.Format.HYBRID
cls.FWD_DTYPE, cls.BWD_DTYPE = _format2dtypes(cls.FP8_FORMAT)
cls.SCALING_MODE = ScalingMode.NVTE_NO_SCALING
cls.FP8_2X_ACC_FPROP = False
cls.FP8_2X_ACC_DGRAD = False
cls.FP8_2X_ACC_WGRAD = False
cls.IF_QUANTIZE_2X = False
# DelayedScaling
cls.AMAX_HISTORY_LEN = 1024
cls.AMAX_COMPUTE_ALGO = AmaxComputeAlgo.MAX
class DelayedScalingQuantizeConfig:
"""Configuration class for delayed scaling FP8 recipe.
This class provides specific initialization and finalization for delayed scaling
FP8 quantization mode.
"""
@staticmethod
def initialize(fp8_recipe: recipe.Recipe) -> None:
"""Initialize delayed scaling FP8 configuration.
Args:
fp8_recipe: The FP8 recipe to use for initialization
Raises:
AssertionError: If recipe parameters are not supported
"""
assert fp8_recipe.amax_compute_algo in [
"max",
"most_recent",
], "DelayedScaling amax_compute_algo only supports max and most_recent with TE/JAX."
assert (
fp8_recipe.scaling_factor_compute_algo is None
), "DelayedScaling scaling_factor_compute_algo isn't supported by TE/JAX."
assert fp8_recipe.reduce_amax, "DelayedScaling reduce_amax should be enabled for TE/JAX."
cls = QuantizeConfig
cls.initialize(fp8_recipe)
cls.AMAX_HISTORY_LEN = fp8_recipe.amax_history_len
string_to_amax_compute_algo = {
"max": AmaxComputeAlgo.MAX,
"most_recent": AmaxComputeAlgo.MOST_RECENT,
}
cls.AMAX_COMPUTE_ALGO = string_to_amax_compute_algo[fp8_recipe.amax_compute_algo]
cls.FP8_2X_ACC_DGRAD = True
cls.FP8_2X_ACC_WGRAD = True
@staticmethod
def finalize() -> None:
"""Reset the delayed scaling configuration."""
QuantizeConfig.finalize()
class BlockScalingQuantizeConfig:
"""Configuration class for block scaling FP8 recipe.
This class provides specific initialization and finalization for block scaling
FP8 quantization mode.
"""
@staticmethod
def initialize(fp8_recipe: recipe.Recipe) -> None:
"""Initialize block scaling FP8 configuration.
Args:
fp8_recipe: The FP8 recipe to use for initialization
"""
cls = QuantizeConfig
cls.initialize(fp8_recipe)
cls.AMAX_HISTORY_LEN = 0
@staticmethod
def finalize() -> None:
"""Reset the block scaling configuration."""
QuantizeConfig.finalize()
@contextmanager
def fp8_autocast(
enabled: bool = False,
fp8_recipe: Optional[recipe.Recipe] = None,
mesh_resource: Optional[MeshResource] = None,
) -> None:
r"""Context manager for FP8 automatic mixed precision.
This context manager enables FP8 quantization for the duration of its context.
.. code-block:: python
mesh_shape = (4, 2)
dp_mesh_axis_name = 'data_parallel'
tp_mesh_axis_name = 'tensor_parallel'
devices = np.asarray(jax.devices()).reshape(*mesh_shape)
with maps.Mesh(devices, (dp_mesh_axis_name, tp_mesh_axis_name)):
mesh_resource=MeshResource(dp_mesh_axis_name, tp_mesh_axis_name)
with fp8_autocast(enabled=True, mesh_resource=mesh_resource):
rules = extend_logical_axis_rules(tuple())
transformer = TransformerLayer()
with partitioning.axis_rules(rules):
pjit(transformer.init, ...)(...)
.. note::
We only support :attr:`margin`, :attr:`fp8_format`, :attr:`amax_history_len`,
and :attr:`amax_compute_algo` (with value 'max' and 'most_recent') in
recipe.DelayedScaling currently. Other parameters in recipe.DelayedScaling
will trigger an assertion.
Parameters
----------
enabled: bool, default = False
Whether or not to enable fp8
fp8_recipe: recipe.Recipe, default = None
Recipe used for FP8 training. Defaults to recipe.DelayedScaling() when None.
mesh_resource: MeshResource, default = None
Specify the mesh axes for data and tensor parallelism to shard along.
If set to None, then no data or tensor parallelism will be used.
"""
if fp8_recipe is None:
fp8_recipe = recipe.DelayedScaling()
if mesh_resource is None:
mesh_resource = MeshResource()
Config = DelayedScalingQuantizeConfig
if isinstance(fp8_recipe, recipe.MXFP8BlockScaling):
Config = BlockScalingQuantizeConfig
try:
with global_shard_guard(mesh_resource):
if enabled:
fp8_available, reason_for_no_fp8 = is_fp8_available(_get_scaling_mode(fp8_recipe))
assert fp8_available, reason_for_no_fp8
Config.initialize(fp8_recipe)
yield
finally:
Config.finalize()
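# Example (illustrative sketch): a minimal single-process use of fp8_autocast.
# `my_model_fn`, `params`, and `inputs` are hypothetical placeholders.
#
#     fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.HYBRID)
#     with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
#         # QuantizeConfig is initialized here, so layers created/called inside this
#         # context pick up NVTE_DELAYED_TENSOR_SCALING and the E4M3/E5M2 dtypes.
#         out = my_model_fn(params, inputs)
#     # On exit, the config is finalized back to the default (no-scaling) state.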
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Metadata classes for quantization in JAX.
This module provides classes for managing quantization metadata, including
scale factors and amax history for different tensor types.
"""
from dataclasses import dataclass
import jax.numpy as jnp
__all__ = ["QuantizeMeta", "QuantizeMetaSet"]
@dataclass
class QuantizeMeta:
"""Metadata for quantization parameters.
Attributes:
scale: The scaling factor for quantization
amax_history: History of maximum absolute values
"""
scale: jnp.ndarray
amax_history: jnp.ndarray
@dataclass
class QuantizeMetaSet:
"""Set of quantization metadata for different tensor types.
Attributes:
x: Quantization metadata for input tensors
kernel: Quantization metadata for kernel tensors
grad: Quantization metadata for gradient tensors
"""
x: QuantizeMeta
kernel: QuantizeMeta
grad: QuantizeMeta
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Tensor quantization classes for TE/JAX.
This module provides classes and utilities for quantizing tensors in JAX.
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from functools import partial
from typing import Union, Optional
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class
from transformer_engine_jax import QuantizeAxis
from .scaling_modes import ScalingMode
from .tensor import ScaledTensor1x, ScaledTensor2x, ScaledTensorFactory
from .helper import (
QuantizeConfig,
AmaxComputeAlgo,
)
__all__ = [
"QuantizeAxis",
"Quantizer",
"QuantizerSet",
"DelayedScaleQuantizer",
"BlockScaleQuantizer",
"QuantizerFactory",
"noop_quantizer_set",
]
@register_pytree_node_class
@dataclass
class Quantizer(ABC):
"""Base class for quantizers.
This abstract class defines the interface for tensor quantization, providing
methods for quantization and scale management.
Attributes:
q_dtype: The data type for quantized values
scaling_mode: The scaling mode to use for quantization
q_axis: The quantization axis (row-wise, column-wise, or both)
"""
q_dtype: jnp.dtype
scaling_mode: ScalingMode
q_axis: QuantizeAxis
def tree_flatten(self):
"""Flatten the quantizer for JAX tree operations.
Returns:
Tuple of (children, aux_data) for tree operations
"""
children = ()
aux_data = (self.q_dtype, self.scaling_mode, self.q_axis)
return (children, aux_data)
@classmethod
def tree_unflatten(cls, aux_data, children):
"""Reconstruct a quantizer from its flattened representation.
Args:
aux_data: Auxiliary data containing quantizer parameters
children: Unused children data
Returns:
A reconstructed Quantizer instance
"""
return cls(*aux_data, *children)
def update(self, *args, **kwargs):
"""Update quantizer state (no-op in base class)."""
del args, kwargs
def is_2x2x(self) -> bool:
"""Check if quantizer uses both row-wise and column-wise quantization.
Returns:
True if using both row-wise and column-wise quantization
"""
return self.q_axis == QuantizeAxis.ROWWISE_COLWISE
@abstractmethod
def get_layout(self) -> str:
"""Get the data layout.
Returns:
Data layout in string format
"""
@abstractmethod
def _quantize_func(self, x, is_colwise=False, dq_dtype=None) -> ScaledTensor1x:
"""Core quantization function to be implemented by subclasses.
Args:
x: Input tensor to quantize
is_colwise: Whether to use column-wise quantization
dq_dtype: Data type for dequantized values, default is x.dtype
Returns:
A ScaledTensor1x containing the quantized data
"""
def quantize(self, x, is_rowwise=False, is_colwise=False, dq_dtype=None):
"""Quantize a tensor using the internal _quantize_func().
Args:
x: Input tensor to quantize
is_rowwise: Whether to use row-wise quantization
is_colwise: Whether to use column-wise quantization
dq_dtype: Data type for dequantized values
Returns:
A ScaledTensor1x or ScaledTensor2x containing the quantized data
"""
if (is_rowwise and is_colwise) or self.is_2x2x():
rowwise_tensor = self._quantize_func(x, dq_dtype=dq_dtype)
colwise_tensor = self._quantize_func(x, is_colwise=True, dq_dtype=dq_dtype)
return ScaledTensor2x(rowwise_tensor, colwise_tensor)
if is_colwise:
return self._quantize_func(x, is_colwise=True, dq_dtype=dq_dtype)
return self._quantize_func(x, dq_dtype=dq_dtype)
def get_scale_shapes(self, data_shape, is_padded=True):
"""Get shapes for scale tensors.
Args:
data_shape: Shape of the input tensor
is_padded: Whether to use padded shapes
Returns:
Tuple of (rowwise_scale_shape, colwise_scale_shape)
"""
return self.scaling_mode.get_scale_shape_2x(data_shape, is_padded)
def get_scale_dtype(self):
"""Get the data type for scale tensors.
Returns:
The data type for scale tensors
"""
return self.scaling_mode.get_scale_dtype()
@register_pytree_node_class
@dataclass
class DelayedScaleQuantizer(Quantizer):
"""Quantizer implementation using delayed scaling.
This quantizer uses delayed scaling mode with float32 scales and maintains
a history of maximum absolute values for dynamic scaling.
Attributes:
scaling_mode: Set to NVTE_DELAYED_TENSOR_SCALING
q_axis: Quantization axis (default: ROWWISE_COLWISE)
scale: Current scaling factor
amax_history: History of maximum absolute values
"""
scaling_mode: ScalingMode = ScalingMode.NVTE_DELAYED_TENSOR_SCALING
q_axis: QuantizeAxis = QuantizeAxis.ROWWISE_COLWISE
scale: jnp.ndarray = field(default_factory=lambda: jnp.ones((1,), jnp.float32))
amax_history: jnp.ndarray = field(
default_factory=lambda: jnp.zeros((QuantizeConfig.AMAX_HISTORY_LEN,), jnp.float32)
)
def tree_flatten(self):
"""Flatten the quantizer for JAX tree operations.
Returns:
Tuple of (children, aux_data) for tree operations
"""
children = (self.scale, self.amax_history)
aux_data = (self.q_dtype, self.scaling_mode, self.q_axis)
return (children, aux_data)
def get_layout(self) -> str:
"""Get the data layout string.
Returns:
Data layout in string format
Raises:
ValueError: If quantization axis is invalid
"""
layout = "NT"
if self.q_axis == QuantizeAxis.ROWWISE_COLWISE:
return layout
if self.q_axis == QuantizeAxis.ROWWISE:
return layout[0]
if self.q_axis == QuantizeAxis.COLWISE:
return layout[1]
raise ValueError(f"Invalid q_axis: {self.q_axis}")
def _quantize_func(self, x: jnp.ndarray, is_colwise=False, dq_dtype=None) -> ScaledTensor1x:
"""Quantize function helper for delayed scaling FP8.
Args:
x: Input tensor to quantize
is_colwise: Whether to use column-wise quantization
dq_dtype: Data type for dequantized values
Returns:
A ScaledTensor1x containing the quantized data
"""
dq_dtype = dq_dtype if dq_dtype is not None else x.dtype
compute_dtype = self.scale.dtype
dtype_max = (jnp.finfo(self.q_dtype).max).astype(compute_dtype)
scaled_x = x.astype(compute_dtype) * self.scale
# quantize() in the old dot.py did it this way; keep this code block here for future debugging
# compute_dtype = x.dtype
# dtype_max = (jnp.finfo(self.q_dtype).max).astype(compute_dtype)
# scaled_x = x * self.scale.astype(compute_dtype)
clipped_scaled_x = jnp.clip(scaled_x, -dtype_max, dtype_max).astype(self.q_dtype)
scale_inv = 1.0 / self.scale
self.update(jnp.max(jnp.abs(x)).reshape((1,)))
return ScaledTensorFactory.create_1x(
data=clipped_scaled_x,
scale_inv=scale_inv,
scaling_mode=self.scaling_mode,
dq_dtype=dq_dtype,
)
def quantize(self, x, is_rowwise: bool = None, is_colwise: bool = None, dq_dtype=None):
"""Quantize a tensor using the internal _quantize_func().
Args:
x: Input tensor to quantize
is_rowwise: Whether to use row-wise quantization
is_colwise: Whether to use column-wise quantization
dq_dtype: Data type for dequantized values
Returns:
A ScaledTensor1x or ScaledTensor2x containing the quantized data
"""
dq_dtype = dq_dtype if dq_dtype is not None else x.dtype
is_rowwise = (
is_rowwise
if is_rowwise is not None
else (self.q_axis == QuantizeAxis.ROWWISE or self.is_2x2x())
)
is_colwise = (
is_colwise
if is_colwise is not None
else (self.q_axis == QuantizeAxis.COLWISE or self.is_2x2x())
)
rowwise_tensor = self._quantize_func(x, dq_dtype=dq_dtype)
colwise_tensor = None
if is_colwise:
colwise_tensor = ScaledTensorFactory.create_1x(
data=jnp.transpose(rowwise_tensor.data, (-1, *range(rowwise_tensor.data.ndim - 1))),
scale_inv=rowwise_tensor.scale_inv,
scaling_mode=self.scaling_mode,
dq_dtype=dq_dtype,
is_colwise=True,
layout="T",
)
if is_colwise and is_rowwise:
return ScaledTensor2x(rowwise_tensor, colwise_tensor)
if is_colwise:
return colwise_tensor
return rowwise_tensor
@staticmethod
@jax.jit
def _update_amax_history(amax_history, new_amax):
"""Update AMAX history with new maximum value.
Args:
amax_history: Current AMAX history
new_amax: New maximum value to add
Returns:
Updated AMAX history
"""
amax_history = amax_history.at[0].set(new_amax[0])
return amax_history
@staticmethod
@partial(jax.jit, static_argnums=(2,))
def _compute_scale(amax_history, scale, q_dtype):
"""Compute new scale based on AMAX history.
Args:
amax_history: History of maximum absolute values
scale: Current scale
q_dtype: Quantization data type
Returns:
Updated scale value
"""
# Calculate the current scale from the amax history
fp8_max = jnp.astype(jnp.finfo(q_dtype).max, jnp.float32)
if QuantizeConfig.AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX:
amax = jnp.max(amax_history, axis=-1, keepdims=True)
else:
amax = amax_history[0:1]
sf = (fp8_max / amax) / (2**QuantizeConfig.MARGIN)
sf = jnp.where(amax > 0.0, sf, scale)
sf = jnp.where(jnp.isfinite(amax), sf, scale)
scale = scale.at[0].set(sf[0])
return scale
@staticmethod
@jax.jit
def _roll_and_reset_amax_history(amax_history):
"""Roll AMAX history and reset first element.
Args:
amax_history: Current AMAX history
Returns:
Updated AMAX history
"""
updated_amax_history = jnp.roll(amax_history, -1, -1)
amax_history = updated_amax_history.at[0].set(0.0)
return amax_history
def update(self, new_amax: jnp.ndarray):
"""Update AMAX history and compute new scale.
Args:
new_amax: New maximum absolute value to add to history
"""
amax_history = self._update_amax_history(self.amax_history, new_amax)
self.scale = self._compute_scale(amax_history, self.scale, self.q_dtype)
self.amax_history = self._roll_and_reset_amax_history(amax_history)
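# Worked example (illustrative, assuming q_dtype=float8_e4m3fn and MARGIN=0): fp8_max
# is 448.0, so with AMAX_COMPUTE_ALGO=MAX and an amax history whose largest entry is
# 2.0 the new scale becomes (448.0 / 2.0) / 2**0 = 224.0. _quantize_func() then
# multiplies inputs by 224.0 before casting and records scale_inv = 1 / 224.0 for
# dequantization; if amax is zero or non-finite, the previous scale is kept instead.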
@register_pytree_node_class
@dataclass
class BlockScaleQuantizer(Quantizer):
"""Quantizer implementation using block-based scaling.
This quantizer uses block scaling mode with FP8 scales and block-based
quantization for improved efficiency.
Attributes:
scaling_mode: Set to NVTE_MXFP8_1D_SCALING
q_axis: Quantization axis (default: ROWWISE_COLWISE)
"""
scaling_mode: ScalingMode = ScalingMode.NVTE_MXFP8_1D_SCALING
q_axis: QuantizeAxis = QuantizeAxis.ROWWISE_COLWISE
def get_layout(self) -> str:
"""Get the data layout string.
Returns:
Data layout in string format
"""
if self.is_2x2x():
return "NN"
return "N"
def _quantize_func(self, x, is_colwise=False, dq_dtype=None) -> ScaledTensor1x:
"""Quantize function helper for block scaling FP8.
Args:
x: Input tensor to quantize
is_colwise: Whether to use column-wise quantization
dq_dtype: Data type for dequantized values
Returns:
A ScaledTensor1x containing the quantized data
"""
# TODO(Phuong): use quantize_func from JAX
dq_dtype = dq_dtype if dq_dtype is not None else x.dtype
x_shape = x.shape
scale_shape = self.scaling_mode.get_scale_shape(x_shape, is_colwise, is_padded=False)
scale_dtype = self.scaling_mode.get_scale_dtype()
x = x.reshape(
*x_shape[:-2],
scale_shape[-2],
int(x_shape[-2] / scale_shape[-2]),
scale_shape[-1],
int(x_shape[-1] / scale_shape[-1]),
)
amax = jnp.max(jnp.abs(x), axis=(-3, -1), keepdims=True)
MAX = jnp.finfo(self.q_dtype).max.astype(jnp.float32)
scales = amax.astype(jnp.float32) / MAX
scales_q = self._cast_to_e8m0_with_rounding_up(scales)
scaled_x = x / self._e8m0_to_dtype(scales_q, jnp.float32)
clipped_x = jnp.clip(scaled_x, -MAX, MAX)
x_q = clipped_x.astype(self.q_dtype).reshape(x_shape)
scales_q = scales_q.reshape(scale_shape).view(scale_dtype)
return ScaledTensorFactory.create_1x(
x_q,
scales_q,
self.scaling_mode,
is_colwise=is_colwise,
dq_dtype=dq_dtype,
)
def _cast_to_e8m0_with_rounding_up(self, scales):
"""Cast scales to E8M0 format with rounding up.
Args:
scales: Input scales to convert
Returns:
Scales in E8M0 format
"""
temp = scales.astype(jnp.float32).view(jnp.uint32)
exp = temp >> 23
mant = temp & 0x7FFFFF
is_ru = jnp.logical_and(
jnp.logical_and((mant > 0), (exp != 0xFE)),
~jnp.logical_and((exp == 0), (mant <= 0x400000)),
)
exp = jnp.where(is_ru, exp + 1, exp)
new_scales = exp.astype(jnp.uint8)
return new_scales
def _e8m0_to_dtype(self, x, dtype):
"""Convert E8M0 format to specified data type.
Args:
x: Input in E8M0 format
dtype: Target data type
Returns:
Converted values in target data type
"""
temp = x.astype(jnp.uint32)
exp = temp << 23
new_x = exp.view(jnp.float32)
near_zero_value = 2**-15 if dtype == jnp.float16 else 2**-127
new_x = jnp.where(new_x == 0, jnp.array(near_zero_value, jnp.float32), new_x)
return new_x.astype(dtype)
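# Worked example (illustrative): _cast_to_e8m0_with_rounding_up() keeps only the FP32
# exponent and rounds up whenever mantissa bits would be lost. A scale of 3.0
# (biased exponent 128, non-zero mantissa) is stored as E8M0 byte 129, i.e. 4.0, so
# data is divided by the next power of two >= the requested scale; an exact power of
# two such as 2.0 keeps its own exponent (128) unchanged.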
@register_pytree_node_class
@dataclass
class QuantizerSet:
"""Set of quantizers for different tensor types.
This class manages quantizers for input tensors, kernel tensors, and
gradient tensors.
Attributes:
x: Quantizer for input tensors
kernel: Quantizer for kernel tensors
dgrad: Quantizer for gradient tensors
"""
x: Optional[Quantizer]
kernel: Optional[Quantizer]
dgrad: Optional[Quantizer]
def tree_flatten(self):
"""Flatten the quantizer set for JAX tree operations.
Returns:
Tuple of (children, aux_data) for tree operations
"""
children = (self.x, self.kernel, self.dgrad)
aux_data = ()
return (children, aux_data)
@classmethod
def tree_unflatten(cls, aux_data, children):
"""Reconstruct a quantizer set from its flattened representation.
Args:
aux_data: Unused auxiliary data
children: Tuple of quantizers
Returns:
A reconstructed QuantizerSet instance
"""
return cls(*aux_data, *children)
@dataclass
class QuantizerFactory:
"""Factory class for creating quantizers.
This class provides static methods to create individual quantizers and
sets of quantizers with various configurations.
"""
quantizer_type_map = {
ScalingMode.NVTE_DELAYED_TENSOR_SCALING: DelayedScaleQuantizer,
ScalingMode.NVTE_MXFP8_1D_SCALING: BlockScaleQuantizer,
}
@staticmethod
def create(
n_quantizers: int = 1,
scaling_mode: ScalingMode = None,
q_dtype: jnp.dtype = None,
q_axis: QuantizeAxis = None,
**kwargs,
) -> Quantizer:
"""Create one or more quantizers with specified parameters.
Args:
n_quantizers: Number of quantizers to create
scaling_mode: Scaling mode to use
q_dtype: Quantization data type
q_axis: Quantization axis
**kwargs: Additional arguments for quantizer initialization
Returns:
A single quantizer or tuple of quantizers
"""
# (Phuong): add this assert back when NVTE_NO_SCALING is fully implemented
# assert scaling_mode != ScalingMode.NVTE_INVALID_SCALING
if scaling_mode in (ScalingMode.NVTE_NO_SCALING, ScalingMode.NVTE_INVALID_SCALING):
quantizers = [None] * n_quantizers
else:
quantizers = []
for _ in range(n_quantizers):
quantizer_type = QuantizerFactory.quantizer_type_map.get(scaling_mode)
quantizers.append(
quantizer_type(
q_dtype=q_dtype, scaling_mode=scaling_mode, q_axis=q_axis, **kwargs
)
)
return quantizers[0] if len(quantizers) == 1 else tuple(quantizers)
@staticmethod
def _create_set(scaling_mode, fwd_dtype, bwd_dtype, is_2x2x, **kwargs) -> QuantizerSet:
"""Create a set of quantizers for forward and backward passes.
Args:
scaling_mode: Scaling mode to use
fwd_dtype: Data type for forward pass
bwd_dtype: Data type for backward pass
is_2x2x: Whether to use 2x2x quantization
**kwargs: Additional arguments for quantizer initialization
Returns:
A QuantizerSet instance
"""
if is_2x2x:
q_axis_x = q_axis_kernel = q_axis_dgrad = QuantizeAxis.ROWWISE_COLWISE
else:
q_axis_x = QuantizeAxis.ROWWISE
q_axis_kernel = QuantizeAxis.COLWISE
q_axis_dgrad = None
if "quantize_meta_set" in kwargs:
quantize_meta_set = kwargs.get("quantize_meta_set")
args_x = {
"scale": quantize_meta_set.x.scale,
"amax_history": quantize_meta_set.x.amax_history,
}
args_kernel = {
"scale": quantize_meta_set.kernel.scale,
"amax_history": quantize_meta_set.kernel.amax_history,
}
args_grad = {
"scale": quantize_meta_set.grad.scale,
"amax_history": quantize_meta_set.grad.amax_history,
}
else:
args_x = args_kernel = args_grad = {}
q_x = QuantizerFactory.create(1, scaling_mode, fwd_dtype, q_axis_x, **args_x)
q_kernel = QuantizerFactory.create(1, scaling_mode, fwd_dtype, q_axis_kernel, **args_kernel)
q_dgrad = QuantizerFactory.create(1, scaling_mode, bwd_dtype, q_axis_dgrad, **args_grad)
return QuantizerSet(x=q_x, kernel=q_kernel, dgrad=q_dgrad)
@staticmethod
def create_set(
n_quantizer_sets: int = 1,
scaling_mode: ScalingMode = None,
fwd_dtype: jnp.dtype = None,
bwd_dtype: jnp.dtype = None,
is_2x2x: bool = None,
**kwargs,
) -> tuple[Union[tuple[Quantizer], None]]:
"""Create one or more sets of quantizers.
Args:
n_quantizer_sets: Number of quantizer sets to create
scaling_mode: Scaling mode to use, default is QuantizeConfig.SCALING_MODE
fwd_dtype: Data type for forward pass, default is QuantizeConfig.FWD_DTYPE
bwd_dtype: Data type for backward pass, default is QuantizeConfig.BWD_DTYPE
is_2x2x: Whether to use 2x2x quantization, default is QuantizeConfig.IF_QUANTIZE_2X
**kwargs: Additional arguments for quantizer initialization
Returns:
A single quantizer set or tuple of quantizer sets
"""
scaling_mode = scaling_mode or QuantizeConfig.SCALING_MODE
fwd_dtype = fwd_dtype or QuantizeConfig.FWD_DTYPE
bwd_dtype = bwd_dtype or QuantizeConfig.BWD_DTYPE
is_2x2x = is_2x2x or QuantizeConfig.IF_QUANTIZE_2X
q_set = []
for _ in range(n_quantizer_sets):
q_set.append(
QuantizerFactory._create_set(scaling_mode, fwd_dtype, bwd_dtype, is_2x2x, **kwargs)
)
return q_set[0] if len(q_set) == 1 else tuple(q_set)
noop_quantizer_set = QuantizerFactory.create_set(scaling_mode=ScalingMode.NVTE_NO_SCALING)
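# Example (illustrative sketch): building a quantizer set under an active fp8_autocast
# and quantizing a hypothetical activation tensor `act`.
#
#     q_set = QuantizerFactory.create_set()  # defaults come from QuantizeConfig
#     scaled = q_set.x.quantize(act, is_rowwise=True, is_colwise=True)
#     act_hi = scaled.dequantize()           # back to the original precision
#
# Outside of fp8_autocast, QuantizeConfig.SCALING_MODE is NVTE_NO_SCALING, so
# create_set() returns a QuantizerSet whose members are None (see noop_quantizer_set above).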
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Scaling mode implementations for quantization in JAX.
This module provides implementations of different scaling modes for tensor quantization,
including delayed scaling and block scaling strategies.
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Tuple, Dict
from functools import reduce
import operator
from jax.tree_util import register_pytree_node_class
import jax.numpy as jnp
__all__ = ["ScalingMode"]
class ScalingModeMetadataImpl(ABC):
"""Base class for scaling mode implementations.
This abstract class defines the interface for different scaling mode implementations,
providing methods to get scale data types and shapes.
"""
@abstractmethod
def get_scale_dtype(self) -> jnp.dtype:
"""Get the data type for scale tensors.
Returns:
The data type used for scale tensors
"""
@abstractmethod
def get_scale_shape(
self, data_shape: Tuple[int, ...], is_colwise: bool = False, is_padded: bool = True
) -> Tuple[int, ...]:
"""Get the shape for scale tensors.
Args:
data_shape: The shape of the tensor being quantized
is_colwise: Whether the scaling is column-wise
is_padded: Whether to return padded shape
Returns:
The shape for scale tensors
"""
class DelayedScalingModeMetadataImpl(ScalingModeMetadataImpl):
"""Implementation for delayed scaling mode.
This implementation provides metadata for delayed scaling mode, including scale data type and shape.
"""
def get_scale_dtype(self) -> jnp.dtype:
"""Get the data type for scale tensors in delayed scaling.
Returns:
The data type used for scale tensors (float32)
"""
return jnp.float32
def get_scale_shape(
self, data_shape: Tuple[int, ...], is_colwise: bool = False, is_padded: bool = True
) -> Tuple[int, ...]:
"""Get the shape for scale tensors in delayed scaling.
Args:
data_shape: The shape of the tensor being scaled
is_colwise: Whether the scaling is column-wise
is_padded: Whether to return padded shape
Returns:
The shape for scale tensors - (1,)
"""
del data_shape, is_colwise
return (1,)
class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
"""Implementation for block scaling mode.
This implementation provides metadata for block scaling mode, which uses
block-based scaling with specific alignment requirements.
Attributes:
_block_dims: Dimensions of the scaling blocks
_block_alignment: Alignment requirements for blocks
"""
def __init__(self, block_dims: Tuple[int]):
"""Initialize block scaling mode implementation.
Args:
block_dims: Dimensions of the scaling blocks
"""
self._block_dims = block_dims
self._block_alignment = (128, 4)
def get_scale_dtype(self) -> jnp.dtype:
"""Get the data type for scale tensors in block scaling.
Returns:
The data type used for scale tensors (float8_e8m0fnu)
"""
return jnp.float8_e8m0fnu
def get_scale_shape(
self, data_shape: Tuple[int, ...], is_colwise: bool = False, is_padded: bool = True
) -> Tuple[int, ...]:
"""Get the shape for scale tensors in block scaling.
Args:
data_shape: The shape of the tensor being quantized
is_colwise: Whether the scaling is column-wise
is_padded: Whether to return padded shape
Returns:
The shape for scale tensors
"""
block_alignment = self._block_alignment if is_padded else (1, 1)
if is_colwise:
block_y, block_x = self._block_dims
alignment_y, alignment_x = block_alignment
else:
block_x, block_y = self._block_dims
alignment_x, alignment_y = block_alignment
seq_axis = len(data_shape) - 2
assert (
data_shape[seq_axis] % block_x == 0
), f"Input data of shape {data_shape} should be padded by {block_x} in axes={seq_axis}"
assert (
data_shape[-1] % block_y == 0
), f"Input data of shape {data_shape} should be padded by {block_y} in axis -1"
# NOTE: this overpads if dim > 2 and dims before seq_axis are greater than 1
n_block_seq = data_shape[seq_axis] // block_x
n_block_y = data_shape[-1] // block_y
n_flat_first_dim = reduce(operator.mul, data_shape[:seq_axis], 1) * n_block_seq
# Padding
n_flat_first_dim = ((n_flat_first_dim + alignment_x - 1) // alignment_x) * alignment_x
n_block_y = ((n_block_y + alignment_y - 1) // alignment_y) * alignment_y
out_shape = ()
for i in range(seq_axis):
d = data_shape[i]
out_shape += (d,)
assert n_flat_first_dim % d == 0
n_flat_first_dim //= d
out_shape += (n_flat_first_dim, n_block_y)
return out_shape
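# Worked example (illustrative, block_dims=(1, 32), block alignment=(128, 4)): for
# row-wise scaling of a (100, 64) tensor the unpadded scale shape is
# (100 // 1, 64 // 32) = (100, 2); with is_padded=True the two dimensions are rounded
# up to multiples of 128 and 4 respectively, giving (128, 4).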
# (Phuong): Map the NVTEScalingMode value to the corresponding ScalingMode.
@dataclass(frozen=True)
@register_pytree_node_class
class ScalingMode(Enum):
"""Enumeration of tensor scaling modes with their corresponding metadata implementations.
This class defines the available scaling modes for tensor quantization:
- NVTE_DELAYED_TENSOR_SCALING: Uses delayed scaling with FP8 data type and float32 scales
- NVTE_MXFP8_1D_SCALING: Uses block-based scaling with FP8 data type and E8M0 scales
- NVTE_INVALID_SCALING: Invalid scaling mode
- NVTE_NO_SCALING: No scaling applied
"""
NVTE_DELAYED_TENSOR_SCALING = 0
NVTE_MXFP8_1D_SCALING = 1
NVTE_INVALID_SCALING = 2
NVTE_NO_SCALING = 3
def _get_impl(self) -> ScalingModeMetadataImpl:
"""Get the implementation for this scaling mode.
Returns:
The scaling mode implementation
Raises:
ValueError: If the scaling mode is invalid
"""
impl = SCALING_MODES_TO_IMPL.get(self)
if impl is None:
raise ValueError("Invalid scaling mode")
return impl
def get_scale_dtype(self):
"""Get the data type for scale tensors in this mode.
Returns:
The data type for scale tensors
"""
return self._get_impl().get_scale_dtype()
def get_scale_shape_2x(self, data_shape, is_padded=True) -> Tuple[Tuple[int]]:
"""Get shapes for both row-wise and column-wise scaling.
Args:
data_shape: Shape of the data tensor
is_padded: Whether to use padded shapes
Returns:
Tuple of (rowwise_scale_shape, colwise_scale_shape)
"""
rowwise_scale_shape = self.get_scale_shape(
data_shape, is_colwise=False, is_padded=is_padded
)
colwise_scale_shape = self.get_scale_shape(data_shape, is_colwise=True, is_padded=is_padded)
return (rowwise_scale_shape, colwise_scale_shape)
def get_scale_shape(self, data_shape, is_colwise, is_padded=True) -> Tuple[int]:
"""Get the shape for scale tensors in this mode.
Args:
data_shape: Shape of the data tensor
is_colwise: Whether to use column-wise scaling
is_padded: Whether to use padded shapes
Returns:
The shape for scale tensors
"""
return self._get_impl().get_scale_shape(data_shape, is_colwise, is_padded)
def __eq__(self, other):
"""Compare this scaling mode with another.
Args:
other: The other scaling mode to compare with
Returns:
True if the modes are equal, False otherwise
"""
if not isinstance(other, ScalingMode):
return False
return self.value == other.value
def tree_flatten(self):
"""Flatten this scaling mode for JAX tree operations.
Returns:
Tuple of (children, aux_data) for tree operations
"""
return (), (self.value)
@classmethod
def tree_unflatten(cls, aux_data, _children):
"""Reconstruct a scaling mode from its flattened representation.
Args:
aux_data: Auxiliary data containing the mode value
_children: Unused children data
Returns:
A reconstructed ScalingMode instance
"""
return cls(aux_data)
SCALING_MODES_TO_IMPL: Dict[ScalingMode, ScalingModeMetadataImpl] = {
ScalingMode.NVTE_DELAYED_TENSOR_SCALING: DelayedScalingModeMetadataImpl(),
ScalingMode.NVTE_MXFP8_1D_SCALING: BlockScalingModeMetadataImpl(block_dims=(1, 32)),
# WAR
ScalingMode.NVTE_NO_SCALING: DelayedScalingModeMetadataImpl(),
}
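# Example (illustrative sketch): querying scale metadata through the ScalingMode API
# rather than the implementation classes above.
#
#     mode = ScalingMode.NVTE_MXFP8_1D_SCALING
#     mode.get_scale_dtype()               # jnp.float8_e8m0fnu
#     mode.get_scale_shape_2x((128, 128))  # ((128, 4), (4, 128)) padded scale shapes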
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Tensor classes for TE/JAX
This module provides tensor classes for handling quantized tensors in JAX, including
both single-scale (1x) and double-scale (2x) quantization schemes. It supports
rowwise and colwise quantization modes with proper scaling and dequantization.
"""
from dataclasses import dataclass
from typing import Callable, Tuple
from abc import ABC, abstractmethod
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class
from transformer_engine_jax import QuantizeAxis
from .scaling_modes import ScalingMode
from .dequantizer import Dequantizer
from ..sharding import (
with_sharding_constraint_by_logical_axes as original_with_sharding_constraint_by_logical_axes,
)
__all__ = [
"ScaledTensor",
"ScaledTensor1x",
"ScaledTensor2x",
"ScaledTensorFactory",
"with_sharding_constraint_by_logical_axes",
]
@register_pytree_node_class
@dataclass
class ScaledTensor(ABC):
"""Abstract base class for scaled tensors.
This class defines the interface for all scaled tensor implementations,
providing methods for dequantization and accessing row/column-wise components.
"""
@classmethod
def tree_unflatten(cls, aux_data, children):
"""Reconstructs the tensor from its flattened representation.
Args:
aux_data: Auxiliary data needed for reconstruction
children: The flattened tensor components
Returns:
A reconstructed tensor instance
"""
return cls(*children, *aux_data)
@abstractmethod
def dequantize(self):
"""Dequantizes the tensor back to its original precision.
Returns:
The dequantized tensor
"""
@abstractmethod
def get_rowwise_tensor(self):
"""Returns the row-wise component of the tensor.
Returns:
The row-wise tensor component
Raises:
ValueError: If called on a tensor that doesn't support row-wise access
"""
@abstractmethod
def get_colwise_tensor(self):
"""Returns the column-wise component of the tensor.
Returns:
The column-wise tensor component
Raises:
ValueError: If called on a tensor that doesn't support column-wise access
"""
@register_pytree_node_class
@dataclass
class ScaledTensor1x(ScaledTensor):
"""Single-scale quantized tensor implementation.
This class represents a tensor quantized with a single scaling factor,
supporting both row-wise and column-wise quantization modes.
Attributes:
data: The quantized tensor data
scale_inv: The inverse scaling factors
scaling_mode: The scaling mode used for quantization
dq_dtype: The data type for dequantized values
_dq_func: The dequantization function
is_colwise: Whether the tensor uses column-wise quantization
layout: The layout specification for the tensor
"""
data: jnp.ndarray
scale_inv: jnp.ndarray
scaling_mode: ScalingMode
dq_dtype: jnp.dtype
_dq_func: Callable
is_colwise: bool
layout: str
def __post_init__(self):
"""Validates and adjusts the scale_inv shape after initialization.
Ensures the scale_inv shape matches the expected shape based on the scaling mode
and quantization direction. Pads the scale_inv if necessary.
"""
expected_scale_shape = self.scaling_mode.get_scale_shape(
self.data.shape, self.is_colwise, is_padded=True
)
expected_unpadded_scale_shape = self.scaling_mode.get_scale_shape(
self.data.shape, self.is_colwise, is_padded=False
)
if self.scale_inv.shape != expected_scale_shape:
assert self.scale_inv.shape == expected_unpadded_scale_shape, (
f"Unexpected scale_inv shape! \nExpect {expected_scale_shape} for padded"
f" scale_inv or {expected_unpadded_scale_shape} for unpadded scale_inv, got"
f" {self.scale_inv.shape}"
)
pad_width = tuple(
(0, a - b) for a, b in zip(expected_scale_shape, expected_unpadded_scale_shape)
)
# This actually pads scale_inv with NaN; should we pad it with 127 directly instead?
self.scale_inv = jnp.pad(
self.scale_inv, pad_width=pad_width, mode="constant", constant_values=0
)
def tree_flatten(self):
"""Flattens the tensor for JAX tree operations.
Returns:
A tuple containing (children, aux_data) for tree operations
"""
children = (self.data, self.scale_inv)
aux_data = (self.scaling_mode, self.dq_dtype, self._dq_func, self.is_colwise, self.layout)
return (children, aux_data)
def dequantize(self):
"""Dequantizes the tensor using the stored dequantization function.
Returns:
The dequantized tensor
"""
return self._dq_func(self)
def get_rowwise_tensor(self):
"""Returns the tensor if it's row-wise quantized.
Returns:
The row-wise tensor
Raises:
ValueError: If called on a column-wise quantized tensor
"""
if not self.is_colwise:
return self
raise ValueError("Calling get_rowwise_tensor() from a colwise ScaledTensor1x!")
def get_colwise_tensor(self):
"""Returns the tensor if it's column-wise quantized.
Returns:
The column-wise tensor
Raises:
ValueError: If called on a row-wise quantized tensor
"""
if self.is_colwise:
return self
raise ValueError("Calling get_colwise_tensor() from a rowwise ScaledTensor1x!")
@register_pytree_node_class
@dataclass
class ScaledTensor2x(ScaledTensor):
"""Double-scale quantized tensor implementation.
This class represents a tensor quantized with both row-wise and column-wise scaling factors.
Attributes:
rowwise_tensor: The row-wise quantized component
colwise_tensor: The column-wise quantized component
"""
rowwise_tensor: ScaledTensor1x
colwise_tensor: ScaledTensor1x
def tree_flatten(self):
"""Flattens the tensor for JAX tree operations.
Returns:
A tuple containing (children, aux_data) for tree operations
"""
children = (self.rowwise_tensor, self.colwise_tensor)
aux_data = ()
return (children, aux_data)
def dequantize(self):
"""Dequantizes the tensor using the row-wise component's dequantization.
Returns:
The dequantized tensor
"""
return self.rowwise_tensor.dequantize()
def get_rowwise_tensor(self):
"""Returns the row-wise quantized component.
Returns:
The row-wise tensor component
"""
return self.rowwise_tensor
def get_colwise_tensor(self):
"""Returns the column-wise quantized component.
Returns:
The column-wise tensor component
"""
return self.colwise_tensor
@dataclass
class ScaledTensorFactory:
"""Factory class for creating scaled tensor instances.
Provides static methods to create both single-scale (1x) and double-scale (2x)
quantized tensors with various configurations.
"""
@staticmethod
def create_1x(
data, scale_inv, scaling_mode, dq_dtype=jnp.bfloat16, is_colwise=False, layout="N"
):
"""Creates a single-scale quantized tensor.
Args:
data: The quantized tensor data
scale_inv: The inverse scaling factors
scaling_mode: The scaling mode for quantization
dq_dtype: The data type for dequantized values (default: bfloat16)
is_colwise: Whether to use column-wise quantization (default: False)
layout: The layout specification (default: "N")
Returns:
A ScaledTensor1x instance
"""
dq_func = Dequantizer.funcs.get(scaling_mode)
return ScaledTensor1x(data, scale_inv, scaling_mode, dq_dtype, dq_func, is_colwise, layout)
@staticmethod
def create_2x(
data,
scale_inv,
colwise_data,
colwise_scale_inv,
scaling_mode,
dq_dtype=jnp.bfloat16,
layout="NN",
):
"""Creates a double-scale quantized tensor.
Args:
data: The row-wise quantized data
scale_inv: The row-wise inverse scaling factors
colwise_data: The column-wise quantized data
colwise_scale_inv: The column-wise inverse scaling factors
scaling_mode: The scaling mode for quantization
dq_dtype: The data type for dequantized values (default: bfloat16)
layout: The layout specification (default: "NN")
Returns:
A ScaledTensor2x instance
"""
dq_func = Dequantizer.funcs.get(scaling_mode)
rowwise_tensor = ScaledTensor1x(
data,
scale_inv,
scaling_mode,
dq_dtype,
dq_func,
is_colwise=False,
layout=layout[0],
)
colwise_tensor = ScaledTensor1x(
colwise_data,
colwise_scale_inv,
scaling_mode,
dq_dtype,
dq_func,
is_colwise=True,
layout=layout[1],
)
return ScaledTensor2x(rowwise_tensor, colwise_tensor)
@staticmethod
def create(
data: jnp.ndarray,
scale_inv: jnp.ndarray,
colwise_data: jnp.ndarray,
colwise_scale_inv: jnp.ndarray,
scaling_mode: ScalingMode,
dq_dtype: jnp.dtype = jnp.bfloat16,
layout: str = "NN",
q_axis: QuantizeAxis = QuantizeAxis.ROWWISE,
):
"""Creates a scaled tensor based on the quantization axis.
Args:
data: The quantized tensor data
scale_inv: The inverse scaling factors
colwise_data: The column-wise quantized data
colwise_scale_inv: The column-wise inverse scaling factors
scaling_mode: The scaling mode for quantization
dq_dtype: The data type for dequantized values (default: bfloat16)
layout: The layout specification (default: "NN")
q_axis: The quantization axis (default: ROWWISE)
Returns:
Either a ScaledTensor1x or ScaledTensor2x instance depending on q_axis
"""
if q_axis == QuantizeAxis.ROWWISE_COLWISE:
return ScaledTensorFactory.create_2x(
data,
scale_inv,
colwise_data,
colwise_scale_inv,
scaling_mode,
dq_dtype,
layout=layout,
)
is_colwise = q_axis == QuantizeAxis.COLWISE
return ScaledTensorFactory.create_1x(
data, scale_inv, scaling_mode, dq_dtype, is_colwise=is_colwise, layout=layout[0]
)
def with_sharding_constraint_by_logical_axes(x, logical_axis_names: Tuple[str, ...]):
"""Applies sharding constraints to a tensor based on logical axis names.
Args:
x: The tensor to apply sharding constraints to
logical_axis_names: Tuple of logical axis names for sharding
Returns:
The tensor with applied sharding constraints
"""
if isinstance(x, ScaledTensor1x):
return ScaledTensor1x(
data=with_sharding_constraint_by_logical_axes(x.data, logical_axis_names),
scale_inv=x.scale_inv,
scaling_mode=x.scaling_mode,
dq_dtype=x.dq_dtype,
_dq_func=x._dq_func,
is_colwise=x.is_colwise,
layout=x.layout,
)
if isinstance(x, ScaledTensor2x):
return ScaledTensor2x(
rowwise_tensor=with_sharding_constraint_by_logical_axes(
x.rowwise_tensor, logical_axis_names
),
colwise_tensor=with_sharding_constraint_by_logical_axes(
x.colwise_tensor, logical_axis_names
),
)
return original_with_sharding_constraint_by_logical_axes(x, logical_axis_names)
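# Example (illustrative sketch): applying logical-axis sharding constraints to a scaled
# tensor. `q_data`, `scale_inv`, and the axis names are hypothetical and must match the
# active mesh axis rules.
#
#     scaled = ScaledTensorFactory.create_1x(q_data, scale_inv, ScalingMode.NVTE_DELAYED_TENSOR_SCALING)
#     scaled = with_sharding_constraint_by_logical_axes(scaled, ("batch", "hidden"))
#     # Only the underlying `data` array receives the constraint; scale_inv is untouched.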
@@ -2,7 +2,22 @@
 #
 # See LICENSE for license information.
-"""Installation script for TE jax extensions."""
+"""Installation script for Transformer Engine JAX extensions.
+This module handles the build and installation of the JAX-specific components
+of Transformer Engine. It manages:
+- JAX extension compilation with pybind11
+- Common header file management
+- Build tool dependencies
+- Package metadata and dependencies
+The script supports both development and release builds, with different
+behaviors for:
+- Build tool management
+- Header file copying
+- Extension compilation
+- Package distribution
+"""
 # pylint: disable=wrong-import-position,wrong-import-order
@@ -41,6 +56,34 @@ CMakeBuildExtension = get_build_ext(BuildExtension, True)
 if __name__ == "__main__":
+"""Main entry point for JAX extension installation.
+This section handles:
+1. Common header file management
+- Creates a temporary directory for common headers
+- Copies necessary header files from the common library
+2. Extension module setup
+- Configures the JAX-specific C++ extension
+- Sets up build paths and dependencies
+3. Package configuration
+- Sets package metadata
+- Configures build and install requirements
+- Sets up extension modules
+4. Cleanup
+- Removes temporary directories after build
+- Cleans up build tools if not in release mode
+Environment variables:
+- NVTE_RELEASE_BUILD: Controls release build behavior
+- NVTE_PROJECT_BUILDING: Set to "1" during build
+Note:
+The script requires JAX to be installed for building.
+It will raise a RuntimeError if JAX is not available.
+"""
 # Extensions
 common_headers_dir = "common_headers"
 copy_common_headers(current_file_path.parent, str(current_file_path / common_headers_dir))
 # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
-"""
-Sharding Meta for xmap with CustomCall
-"""
+"""Sharding utilities for Transformer Engine in JAX.
+This module provides utilities for managing tensor sharding in distributed training,
+including support for various parallelism strategies like data parallelism (DP),
+tensor parallelism (TP), pipeline parallelism (PP), and full-sharded data
+parallelism (FSDP). It includes functions for sharding constraints, mesh management,
+and collective operations.
+"""
 import os
 from contextlib import contextmanager
@@ -181,27 +186,17 @@ def get_mesh_axis_rank(axis: str, mesh=None):
 @dataclass
 class MeshResource:
-"""
-A data container to indicate which axis in Mesh for data parallelism and
-which for tensor parallelism.
-Parameters
-----------
-dp_resource : str, default = None
-The axis name in Mesh used to shard batches along.
-If it is None, then data parallelism is disabled.
-tp_resource : str, default = None
-The axis name in Mesh used to split the hidden dimensions along.
-If it is None, then tensor parallelism is disabled.
-fsdp_resource : str, default = None
-The axis name in Mesh used to split the batch and weights along.
-If it is None, then full-sharded data parallelism is disabled.
-pp_resource : str, default = None
-The axis name in Mesh used to split model layers along.
-If it is None, then pipeline parallelism is disabled.
-cp_resource : str, default = None
-The axis name in Mesh used to split sequence (context) dimensions along
-in the attention. If it is None, then context parallelism is disabled.
-"""
+"""A data container for managing mesh resources in distributed training.
+This class defines the mapping between logical axes and physical mesh axes
+for different types of parallelism in distributed training.
+Attributes:
+dp_resource: Axis name for data parallelism (batch sharding), default is None
+tp_resource: Axis name for tensor parallelism (hidden dimension sharding), default is None
+fsdp_resource: Axis name for full-sharded data parallelism, default is None
+pp_resource: Axis name for pipeline parallelism (layer sharding), default is None
+cp_resource: Axis name for context parallelism (sequence sharding), default is None
+"""
 dp_resource: str = None
@@ -216,36 +211,55 @@ _GLOBAL_MESH_RESOURCE = MeshResource()
 @contextmanager
 def global_shard_guard(resource: MeshResource):
-"""
-A context manager to switch the global MeshResource
-"""
+"""Context manager for setting global sharding configuration.
+This context manager allows temporarily setting the global mesh resource
+configuration for sharding operations.
+Args:
+resource: MeshResource instance defining the sharding configuration
+"""
 global _GLOBAL_MESH_RESOURCE
-prev_gmr = _GLOBAL_MESH_RESOURCE
+old_resources = _GLOBAL_MESH_RESOURCE
 try:
 _GLOBAL_MESH_RESOURCE = resource
 yield
 finally:
-_GLOBAL_MESH_RESOURCE = prev_gmr
+_GLOBAL_MESH_RESOURCE = old_resources
 def global_mesh_resource() -> MeshResource:
-"""
-A getter of the global MeshResource
-"""
+"""Get the current global mesh resource configuration.
+Returns:
+The current MeshResource instance
+"""
 return _GLOBAL_MESH_RESOURCE
 def all_reduce_sum_along_dp_fsdp(x: jnp.array, mesh: jax.sharding.Mesh):
-"""
-All-Reduce (Sum) along DP and FSDP mesh axes.
-"""
+"""Perform all-reduce sum operation along data parallelism and FSDP axes.
+Args:
+x: Input tensor to reduce
+mesh: JAX mesh for distributed computation
+Returns:
+Reduced tensor
+"""
 x = lax_paral_op(x, jax.lax.psum, global_mesh_resource().dp_resource, mesh)
 return lax_paral_op(x, jax.lax.psum, global_mesh_resource().fsdp_resource, mesh)
 def all_reduce_max_along_all_axes_except_PP(x: jnp.array, mesh: jax.sharding.Mesh):
-"""
-All-Reduce (Max) along all mesh axes.
-"""
+"""Perform all-reduce max operation along all axes except pipeline parallelism.
+Args:
+x: Input tensor to reduce
+mesh: JAX mesh for distributed computation
+Returns:
+Reduced tensor
+"""
 all_axes = get_all_mesh_axes()
 for axis in all_axes:
@@ -261,21 +275,16 @@ global_shard_resource = global_mesh_resource
 class MajorShardingType(Enum):
-r"""
-The major sharding type to indicate sharding pattern.
-.. warning::
-MajorShardingType is deprecating in the near feature.
-Values
-----------
-SINGLE:
-Single process training.
-DP:
-Data parallel training.
-TP:
-Standard tensor parallel training.
-DPTP:
-Data and Standard tensor parallel training.
-"""
+"""Enumeration of major sharding types for distributed training.
+This enum defines the basic sharding patterns available for distributed
+training. Note that this class is deprecated and will be removed in the future.
+Values:
+SINGLE: Single process training
+DP: Data parallel training
+TP: Standard tensor parallel training
+DPTP: Data and standard tensor parallel training
+"""
 SINGLE = 0
@@ -285,25 +294,19 @@ class MajorShardingType(Enum):
 class ShardingType(Enum):
-"""
-The sharding type to indicate sharding pattern.
-.. warning::
-ShardingType is deprecating in the near feature.
-Values
-----------
-SINGLE:
-No sharding.
-DP:
-Sharding along data parallelism.
-TP_COL:
-Sharding along column-split tensor parallelism.
-TP_ROW:
-Sharding along row-split tensor parallelism.
-DP_TP_COL:
-Sharding along data and column-split tensor parallelism.
-DP_TP_ROW:
-Sharding along data and row-split tensor parallelism.
-"""
+"""Enumeration of detailed sharding types for distributed training.
+This enum defines specific sharding patterns for distributed training,
+including combinations of data parallelism and different tensor parallelism
+strategies. Note that this class is deprecated and will be removed in the future.
+Values:
+SINGLE: No sharding
+DP: Sharding along data parallelism
+TP_COL: Sharding along column-split tensor parallelism
+TP_ROW: Sharding along row-split tensor parallelism
+DP_TP_COL: Sharding along data and column-split tensor parallelism
+DP_TP_ROW: Sharding along data and row-split tensor parallelism
+"""
 SINGLE = (MajorShardingType.SINGLE, "single")
@@ -690,9 +690,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                     # partial result quantizer
                     for i in range(cp_size):
                         S_quantizer_per_step[i] = S_quantizer.copy()
-                        S_quantizer_per_step[i].amax = amax_per_step[0][i]
+                        S_quantizer_per_step[i].amax = amax_per_step[0][i].reshape((1,))
                         O_CP_quantizer_per_step[i] = O_CP_quantizer.copy()
-                        O_CP_quantizer_per_step[i].amax = amax_per_step[1][i]
+                        O_CP_quantizer_per_step[i].amax = amax_per_step[1][i].reshape((1,))
                 else:
                     assert False, "FP8 is only supported with Fused Attention!"
         else:
@@ -1361,16 +1361,15 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
             if i > 1:
                 flash_attn_streams[(i - 1) % 2].wait_event(fwd_results_correction_done)
-                if use_fused_attention:
-                    # [b, np, sq, 1] -> [b, np, sq] or
-                    # [t, np, 1] -> [t, np]
-                    softmax_lse_per_step[i - 1].squeeze_(-1)
-                    if softmax_lse_in_packed_format:
-                        softmax_lse_per_step[i - 1] = (
-                            softmax_lse_per_step[i - 1].transpose(0, 1).contiguous()
-                        )
             with torch.cuda.stream(flash_attn_streams[(i - 1) % 2]):
+                if use_fused_attention:
+                    # [b, np, sq, 1] -> [b, np, sq] or
+                    # [t, np, 1] -> [t, np]
+                    softmax_lse_per_step[i - 1].squeeze_(-1)
+                    if softmax_lse_in_packed_format:
+                        softmax_lse_per_step[i - 1] = (
+                            softmax_lse_per_step[i - 1].transpose(0, 1).contiguous()
+                        )
                 if fp8:
                     out_per_step[i - 1] = out_per_step[i - 1].dequantize(dtype=torch.float32)
                 if i == 1:
@@ -1479,8 +1478,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
         if fp8 and use_fused_attention:
             amax_cp_fwd = amax_per_step.amax(dim=1)
-            S_quantizer.amax = amax_cp_fwd[0]
-            O_CP_quantizer.amax = amax_cp_fwd[1]
+            S_quantizer.amax.copy_(amax_cp_fwd[0])
+            O_CP_quantizer.amax.copy_(amax_cp_fwd[1])
         out_fp8 = None
         out_f16 = out.to(qkv_dtype)
@@ -1513,16 +1512,6 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
         ctx.save_for_backward(*tensors_to_save)
         ctx.tensor_objects = tensor_objects
-        ctx.qkv_dtype = qkv_dtype
-        ctx.QKV_quantizer = QKV_quantizer
-        ctx.O_quantizer = O_quantizer
-        ctx.O_CP_quantizer = O_CP_quantizer
-        ctx.S_quantizer = S_quantizer
-        ctx.dQKV_quantizer = dQKV_quantizer
-        ctx.dQKV_CP_quantizer = dQKV_CP_quantizer
-        ctx.dO_quantizer = dO_quantizer
-        ctx.dP_quantizer = dP_quantizer
         ctx.cp_group_a2a = cp_group_a2a
         ctx.cp_size_a2a = cp_size_a2a
         ctx.rank_a2a = rank_a2a
@@ -1546,6 +1535,22 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
         ctx.is_input_fp8 = is_input_fp8
         ctx.is_output_fp8 = is_output_fp8
         ctx.use_flash_attn_3 = use_flash_attn_3
+        ctx.qkv_dtype = qkv_dtype
+        ctx.dQKV_quantizer = dQKV_quantizer
+        ctx.dQKV_CP_quantizer = dQKV_CP_quantizer
+        ctx.dO_quantizer = dO_quantizer
+        ctx.dP_quantizer = dP_quantizer
+        ctx.QKV_quantizer = QKV_quantizer
+        ctx.O_quantizer = O_quantizer
+        ctx.S_quantizer = S_quantizer
+        if ctx.fp8:
+            ctx.QKV_quantizer = QKV_quantizer.copy()
+            ctx.QKV_quantizer.scale = QKV_quantizer.scale.clone()
+            ctx.O_quantizer = O_quantizer.copy()
+            ctx.O_quantizer.scale = O_quantizer.scale.clone()
+            ctx.S_quantizer = S_quantizer.copy()
+            ctx.S_quantizer.scale = S_quantizer.scale.clone()
         nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVP2P.forward")
         return out_ret
@@ -1632,32 +1637,27 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                 if ctx.use_fused_attention:
                     fused_attn_backend = FusedAttnBackend["FP8"]
-                    dqkv_fp8_torch_dtype = get_fp8_torch_dtype(
-                        ctx.fp8_meta["recipe"], fprop_tensor=False
-                    )
-                    dq_fp8 = torch.empty(
-                        (cp_size, *q.shape), dtype=dqkv_fp8_torch_dtype, device=q.device
-                    )
-                    dkv_fp8 = torch.empty(
-                        (cp_size, *kv.shape), dtype=dqkv_fp8_torch_dtype, device=kv.device
-                    )
-                    dkv_fp8_ = torch.empty_like(dkv_fp8)
                     if ctx.is_output_fp8:
                         assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!"
                         ctx.dO_quantizer = dout._quantizer
                     else:
                         dout = ctx.dO_quantizer(dout)
-                    fused_attn_dqkv_dtype = dout._fp8_dtype
-                    dout = dout._data
+                    fused_attn_dqkv_dtype = TE_DType[dout._data.dtype]
+                    dq_fp8 = torch.empty((cp_size, *q.shape), dtype=dout._data.dtype, device=q.device)
+                    dkv_fp8 = torch.empty(
+                        (cp_size, *kv.shape), dtype=dout._data.dtype, device=kv.device
+                    )
+                    dkv_fp8_ = torch.empty_like(dkv_fp8)
                     p2p_comm_buffers = [[kv, dkv_fp8], [torch.empty_like(kv), dkv_fp8_]]
+                    dout = dout._data
                     fp8_meta_kwargs = {}
                     fp8_meta_kwargs["s_quantizer"] = ctx.S_quantizer
                     amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device)
                     for i in range(cp_size):
                         dP_quantizer_per_step[i] = ctx.dP_quantizer.copy()
-                        dP_quantizer_per_step[i].amax = amax_per_step[0][i]
+                        dP_quantizer_per_step[i].amax = amax_per_step[0][i].reshape((1,))
                         dQKV_CP_quantizer_per_step[i] = ctx.dQKV_CP_quantizer.copy()
-                        dQKV_CP_quantizer_per_step[i].amax = amax_per_step[1][i]
+                        dQKV_CP_quantizer_per_step[i].amax = amax_per_step[1][i].reshape((1,))
                 else:
                     assert False, "FP8 is only supported with Fused Attention!"
             else:
@@ -1838,7 +1838,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                         v_part,
                         out_part,
                         dout_part,
-                        ctx.qkv_dtype,
+                        dout_dtype,
                         fused_attn_dqkv_dtype,
                         aux_ctx_tensors,
                         fused_attn_backend,
@@ -1962,7 +1962,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                         v_part,
                         out_part,
                         dout_part,
-                        ctx.qkv_dtype,
+                        dout_dtype,
                         fused_attn_dqkv_dtype,
                         aux_ctx_tensors,
                         fused_attn_backend,
@@ -2090,7 +2090,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                         v_part,
                         out_part,
                         dout_part,
-                        ctx.qkv_dtype,
+                        dout_dtype,
                         fused_attn_dqkv_dtype,
                         aux_ctx_tensors,
                         fused_attn_backend,
@@ -2195,7 +2195,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                         v_part,
                         out_part,
                         dout_part,
-                        ctx.qkv_dtype,
+                        dout_dtype,
                         fused_attn_dqkv_dtype,
                         aux_ctx_tensors,
                         fused_attn_backend,
@@ -2395,8 +2395,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
         if ctx.fp8 and ctx.use_fused_attention:
             amax_cp_bwd = amax_per_step.amax(dim=1)
-            ctx.dP_quantizer.amax = amax_cp_bwd[0]
-            ctx.dQKV_CP_quantizer.amax = amax_cp_bwd[1]
+            ctx.dP_quantizer.amax.copy_(amax_cp_bwd[0])
+            ctx.dQKV_CP_quantizer.amax.copy_(amax_cp_bwd[1])
         if ctx.qkv_format in ["bshd", "sbhd"]:
             # [cp, b, 2, sk//2, 2, np, hn] -> [cp, 2, b, 2, sk//2, np, hn] or
             # [cp, 2, sk//2, b, 2, np, hn] -> [cp, 2, 2, sk//2, b, np, hn]
@@ -3229,14 +3229,6 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
         ctx.save_for_backward(*tensors_to_save)
         ctx.tensor_objects = tensor_objects
-        ctx.qkv_dtype = qkv_dtype
-        ctx.QKV_quantizer = QKV_quantizer
-        ctx.O_quantizer = O_quantizer
-        ctx.S_quantizer = S_quantizer
-        ctx.dQKV_quantizer = dQKV_quantizer
-        ctx.dO_quantizer = dO_quantizer
-        ctx.dP_quantizer = dP_quantizer
         ctx.batch_size = batch_size
         ctx.cp_group = cp_group
         ctx.cp_stream = cp_stream
@@ -3255,6 +3247,21 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
         ctx.is_input_fp8 = is_input_fp8
         ctx.is_output_fp8 = is_output_fp8
         ctx.use_flash_attn_3 = use_flash_attn_3
+        ctx.qkv_dtype = qkv_dtype
+        ctx.dQKV_quantizer = dQKV_quantizer
+        ctx.dO_quantizer = dO_quantizer
+        ctx.dP_quantizer = dP_quantizer
+        ctx.QKV_quantizer = QKV_quantizer
+        ctx.O_quantizer = O_quantizer
+        ctx.S_quantizer = S_quantizer
+        if ctx.fp8:
+            ctx.QKV_quantizer = QKV_quantizer.copy()
+            ctx.QKV_quantizer.scale = QKV_quantizer.scale.clone()
+            ctx.O_quantizer = O_quantizer.copy()
+            ctx.O_quantizer.scale = O_quantizer.scale.clone()
+            ctx.S_quantizer = S_quantizer.copy()
+            ctx.S_quantizer.scale = S_quantizer.scale.clone()
         nvtx_range_pop("transformer_engine.AttnFuncWithCPAndQKVOA2A.forward")
         return out_ret
@@ -3291,7 +3298,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
                 ctx.dO_quantizer = dout._quantizer
             else:
                 dout = ctx.dO_quantizer(dout)
-            fused_attn_dqkv_dtype = dout._fp8_dtype
+            fused_attn_dqkv_dtype = TE_DType[dout._data.dtype]
             dout = dout._data
             fp8_meta_kwargs = {}
             fp8_meta_kwargs["s_quantizer"] = ctx.S_quantizer
@@ -3401,7 +3408,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
                 v_part,
                 out_part,
                 dout_part,
-                ctx.qkv_dtype,
+                dout_dtype,
                 fused_attn_dqkv_dtype,
                 aux_ctx_tensors,
                 fused_attn_backend,
@@ -4748,6 +4755,9 @@ class FusedAttnFunc(torch.autograd.Function):
         ctx.dO_quantizer = dO_quantizer
         ctx.dP_quantizer = dP_quantizer
         ctx.S_quantizer = S_quantizer
+        if ctx.fp8:
+            ctx.S_quantizer = S_quantizer.copy()
+            ctx.S_quantizer.scale = S_quantizer.scale.clone()
         ctx.max_seqlen_q = max_seqlen_q
         ctx.max_seqlen_kv = max_seqlen_kv
@@ -4963,8 +4973,6 @@ class FusedAttnFunc(torch.autograd.Function):
                 None,
                 None,
                 None,
-                None,
-                None,
             )
         # else, return (dqkv, dbias)
         return (
@@ -4995,8 +5003,6 @@ class FusedAttnFunc(torch.autograd.Function):
             None,
            None,
            None,
-            None,
-            None,
        )
@@ -5126,6 +5132,16 @@ class FusedAttention(torch.nn.Module):
         # get q_format and kv_format for training and inference
         qkv_format, q_format, kv_format = dpa_utils.get_qkv_format(qkv_layout, inference_params)
+        # cuDNN can work with 0-length sequences in the batch for both bshd/sbhd and thd formats
+        # however, for bshd/sbhd, q/k/v tensors need to have the same batch size as indicated by
+        # cu_seqlens, whereas thd does not have this requirement
+        # e.g. if q_format = bshd, and q.shape = [3, 1, 16, 64], we should have k.shape[0] =
+        # v.shape[0] = q.shape[0], and cu_seqlens_q.shape = cu_seqlens_kv.shape = [4]
+        if q_format in ["bshd", "sbhd"] or kv_format in ["bshd", "sbhd"]:
+            batch_size = query_layer.shape[0] if q_format == "bshd" else query_layer.shape[1]
+            cu_seqlens_q = cu_seqlens_q[: batch_size + 1]
+            cu_seqlens_kv = cu_seqlens_kv[: batch_size + 1]
         page_table = None
         if inference_params is None:
             if qkv_format in ["sbhd", "bshd"]:
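For illustration, a hedged sketch of the cu_seqlens truncation added above, using hypothetical shapes (values are made up for the example):

    import torch

    q = torch.randn(3, 1, 16, 64)  # bshd layout, so batch_size = q.shape[0] = 3
    # cu_seqlens allocated for a larger maximum batch; only batch_size + 1 entries are valid
    cu_seqlens_q = torch.tensor([0, 1, 2, 3, 3, 3], dtype=torch.int32)
    batch_size = q.shape[0]
    cu_seqlens_q = cu_seqlens_q[: batch_size + 1]  # -> tensor([0, 1, 2, 3]), shape [4]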
@@ -6209,7 +6225,11 @@ class DotProductAttention(TransformerEngineBaseModule):
         # raise exception if no backend is available
         if sum([use_flash_attention, use_fused_attention, use_unfused_attention]) == 0:
-            raise ValueError("No dot product attention support for the provided inputs!")
+            raise ValueError(
+                "No dot product attention backend is available for the provided inputs. Please"
+                " run with NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=2 to find out the reasons for"
+                " disabling all backends."
+            )
         # run attention
         if use_flash_attention:
......
@@ -153,7 +153,6 @@ class Float8CurrentScalingQuantizer : public Quantizer {
   DType dtype;
   bool with_amax_reduction;
   c10::intrusive_ptr<dist_group_type> amax_reduction_group;
-  int amax_reduction_size;
   bool force_pow_2_scales = false;
   float amax_epsilon = 0.0;
......
@@ -145,24 +145,21 @@ Float8CurrentScalingQuantizer::Float8CurrentScalingQuantizer(const py::handle& q
   const at::Tensor& scale = quantizer.attr("scale").cast<at::Tensor>();
   const at::Tensor& amax = quantizer.attr("amax").cast<at::Tensor>();
   const DType type = quantizer.attr("dtype").cast<DType>();
-  // For current scaling, need several other components:
-  // 1. with_amax_reduction: bool
-  // 2. amax_reduction_group: torch.distributed.ProcessGroup or None
-  // 3. amax_reduction_size: int
-  const bool with_amax_reduction = quantizer.attr("with_amax_reduction").cast<bool>();
-  const py::object amax_reduction_group_obj = quantizer.attr("amax_reduction_group");
-  const c10::intrusive_ptr<dist_group_type> amax_reduction_group =
-      amax_reduction_group_obj.is_none()
-          ? nullptr
-          : amax_reduction_group_obj.cast<c10::intrusive_ptr<dist_group_type>>();
-  const int amax_reduction_size = quantizer.attr("amax_reduction_size").cast<int>();
   this->amax = amax;
   this->scale = scale;
   this->dtype = type;
+  // Get amax reduction group if needed
+  const bool with_amax_reduction = quantizer.attr("with_amax_reduction").cast<bool>();
+  c10::intrusive_ptr<dist_group_type> amax_reduction_group;
+  if (with_amax_reduction) {
+    auto group = quantizer.attr("_canonicalized_amax_reduction_group")();
+    NVTE_CHECK(!group.is_none(),
+               "Float8CurrentScalingQuantizer could not canonicalize amax reduction group");
+    amax_reduction_group = group.cast<c10::intrusive_ptr<dist_group_type>>();
+  }
   this->with_amax_reduction = with_amax_reduction;
   this->amax_reduction_group = amax_reduction_group;
-  this->amax_reduction_size = amax_reduction_size;
   // fp8 current scaling specific quantization params
   this->force_pow_2_scales = quantizer.attr("force_pow_2_scales").cast<bool>();
......
@@ -5,6 +5,7 @@
 """Methods needed for distributed training (DP/TP)."""
 from __future__ import annotations
+from collections.abc import Iterable
 from contextlib import contextmanager, AbstractContextManager, ContextDecorator
 from functools import lru_cache
 import math
@@ -876,10 +877,14 @@ def _all_gather_fp8(
         # we cannot directly gather the transposed fp8 tensor
         # so we need to disable columnwise usage for the quantizer
         # and then set it back to the original value after quantizing
+        init_rowwise_usage = quantizer.rowwise_usage
         init_columnwise_usage = quantizer.columnwise_usage
-        quantizer.set_usage(columnwise=False)
+        quantizer.set_usage(rowwise=True, columnwise=False)
         inp = quantizer(inp)
-        quantizer.set_usage(columnwise=init_columnwise_usage)
+        quantizer.set_usage(
+            rowwise=init_rowwise_usage,
+            columnwise=init_columnwise_usage,
+        )

     # Construct output tensor
     out: Float8TensorBase
@@ -936,9 +941,34 @@ def _all_gather_mxfp8(
 ) -> tuple[MXFP8TensorBase, Optional[torch.distributed.Work]]:
     """All-gather MXFP8 tensor along first dimension."""

-    # Tensor dims
+    # Input tensor attributes
+    in_shape: Iterable[int]
+    device: torch.device
+    dtype: torch.dtype
+    if isinstance(inp, torch.Tensor):
+        in_shape = inp.size()
+        device = inp.device
+        dtype = inp.dtype
+    elif isinstance(inp, MXFP8TensorBase):
+        if inp._rowwise_data is not None:
+            in_shape = inp._rowwise_data.size()
+            device = inp._rowwise_data.device
+            dtype = inp._rowwise_data.dtype
+        elif inp._columnwise_data is not None:
+            in_shape = inp._columnwise_data.size()
+            device = inp._columnwise_data.device
+            dtype = inp._columnwise_data.dtype
+        else:
+            raise ValueError("Got MXFP8 input tensor without any data")
+        dtype = torch.bfloat16
+    else:
+        raise ValueError(
+            "Invalid type for input tensor (expected torch.Tensor or MXFP8TensorBase, "
+            f"found {inp.__class__.__name__})"
+        )
+
+    # Output tensor shape
     world_size = get_distributed_world_size(process_group)
-    in_shape = list(inp.size())
     if out_shape is None:
         out_shape = [in_shape[0] * world_size] + in_shape[1:]
@@ -951,25 +981,19 @@ def _all_gather_mxfp8(
     ):
         out = torch.empty(
             out_shape,
-            dtype=inp.dtype,
-            device=inp.device,
+            dtype=dtype,
+            device=device,
             memory_format=torch.contiguous_format,
         )
         torch.distributed.all_gather_into_tensor(out, inp, group=process_group)
         out = quantizer(out)
         return out, None

-    inp_dtype = inp.dtype
-    inp_device = inp.device
-
     # Cast input tensor to MXFP8 with required data
     if not isinstance(inp, MXFP8TensorBase):
         inp = quantizer(inp)
-    elif (
-        inp.rowwise_data is None
-        and quantizer.rowwise_usage
-        or inp.columnwise_data is None
-        and quantizer.columnwise_usage
+    elif (quantizer.rowwise_usage and inp._rowwise_data is None) or (
+        quantizer.columnwise_usage and inp._columnwise_data is None
     ):
         warnings.warn(
             "Input and quantizer do not have matching usages. "
@@ -978,65 +1002,64 @@ def _all_gather_mxfp8(
         inp = quantizer(inp.dequantize())

     # Construct MXFP8 output tensor
-    out = quantizer.make_empty(out_shape, dtype=inp_dtype, device=inp_device)
+    out = quantizer.make_empty(out_shape, dtype=dtype, device=device)
# Async op handle
handle = None
# Gather MXFP8 data for row-wise usage
if quantizer.rowwise_usage:
# Remove padding from MXFP8 scale-inverses
in_scale_inv = inp._rowwise_scale_inv
out_scale_inv = out._rowwise_scale_inv
flattened_in_shape0 = math.prod(in_shape[:-1])
if in_scale_inv.size(0) != flattened_in_shape0:
in_scale_inv = in_scale_inv[:flattened_in_shape0]
out_scale_inv[flattened_in_shape0 * world_size :].zero_()
out_scale_inv = out_scale_inv[: flattened_in_shape0 * world_size]
# Launch all-gathers
if handle is not None:
handle.wait()
torch.distributed.all_gather_into_tensor(
out_scale_inv,
in_scale_inv,
group=process_group,
)
handle = torch.distributed.all_gather_into_tensor(
out._rowwise_data,
inp._rowwise_data,
group=process_group,
async_op=async_op,
)
# Gather MXFP8 data for column-wise usage
if quantizer.columnwise_usage:
# Remove padding from MXFP8 scale-inverses # Coalesce NCCL collectives
in_scale_inv = inp._columnwise_scale_inv with torch.distributed._coalescing_manager(
out_scale_inv = out._columnwise_scale_inv group=process_group,
flattened_in_shape0 = math.prod(in_shape[:-1]) // 32 device=device,
if in_scale_inv.size(0) != flattened_in_shape0: async_ops=async_op,
in_scale_inv = in_scale_inv[:flattened_in_shape0] ) as coalescing_manager:
out_scale_inv[flattened_in_shape0 * world_size :].zero_()
out_scale_inv = out_scale_inv[: flattened_in_shape0 * world_size] # Gather MXFP8 data for row-wise usage
if quantizer.rowwise_usage:
# Remove padding from MXFP8 scale-inverses
in_scale_inv = inp._rowwise_scale_inv
out_scale_inv = out._rowwise_scale_inv
flattened_in_shape0 = math.prod(in_shape[:-1])
if in_scale_inv.size(0) != flattened_in_shape0:
in_scale_inv = in_scale_inv[:flattened_in_shape0]
out_scale_inv[flattened_in_shape0 * world_size :].zero_()
out_scale_inv = out_scale_inv[: flattened_in_shape0 * world_size]
# Launch all-gathers
torch.distributed.all_gather_into_tensor(
out_scale_inv,
in_scale_inv,
group=process_group,
)
torch.distributed.all_gather_into_tensor(
out._rowwise_data,
inp._rowwise_data,
group=process_group,
)
# Launch all-gathers # Gather MXFP8 data for column-wise usage
if handle is not None: if quantizer.columnwise_usage:
handle.wait()
torch.distributed.all_gather_into_tensor( # Remove padding from MXFP8 scale-inverses
out_scale_inv, in_scale_inv = inp._columnwise_scale_inv
in_scale_inv, out_scale_inv = out._columnwise_scale_inv
group=process_group, flattened_in_shape0 = math.prod(in_shape[:-1]) // 32
) if in_scale_inv.size(0) != flattened_in_shape0:
handle = torch.distributed.all_gather_into_tensor( in_scale_inv = in_scale_inv[:flattened_in_shape0]
out._columnwise_data, out_scale_inv[flattened_in_shape0 * world_size :].zero_()
inp._columnwise_data, out_scale_inv = out_scale_inv[: flattened_in_shape0 * world_size]
group=process_group,
async_op=async_op, # Launch all-gathers
) torch.distributed.all_gather_into_tensor(
out_scale_inv,
in_scale_inv,
group=process_group,
)
torch.distributed.all_gather_into_tensor(
out._columnwise_data,
inp._columnwise_data,
group=process_group,
)
handle = coalescing_manager if async_op else None
return out, handle return out, handle
......
@@ -100,7 +100,7 @@ class InferenceParams:
     ----------
     max_batch_size: int
         Maximum batch size in inference
-    max_seqlen_kv: int
+    max_sequence_length: int
         Maximum sequence length in inference
     num_heads_kv: int
         Number of attention heads in keys and values
@@ -117,7 +117,7 @@ class InferenceParams:
     page_size: int, default = None
         Page size of the KV cache. Required for is_paged = True.
     max_ctx_len: int, default = None
-        Maximum context length in inference. 1 <= max_ctx_len <= max_seqlen_kv.
+        Maximum context length in inference. 1 <= max_ctx_len <= max_sequence_length.
     qkv_format: str, default = "bshd"
         Format of the incoming query/key/value tensors in current iteration
     custom_cache_manager: KVCacheManager, default = None
@@ -127,7 +127,7 @@ class InferenceParams:
     def __init__(
         self,
         max_batch_size: int,
-        max_seqlen_kv: int,
+        max_sequence_length: int,
         num_heads_kv: int = 16,
         head_dim_k: int = 64,
         dtype: torch.dtype = torch.bfloat16,
@@ -140,7 +140,7 @@ class InferenceParams:
         custom_cache_manager: KVCacheManager = None,
     ):
         self.max_batch_size = max_batch_size
-        self.max_seqlen_kv = max_seqlen_kv
+        self.max_sequence_length = max_sequence_length
         self.num_heads_kv = num_heads_kv
         self.head_dim_k = head_dim_k
         self.dtype = dtype
@@ -153,7 +153,7 @@ class InferenceParams:
         )
         self.cache_manager = cache_manager(
             max_batch_size=self.max_batch_size,
-            max_seqlen=self.max_seqlen_kv,
+            max_seqlen=self.max_sequence_length,
             num_heads=self.num_heads_kv,
             head_dim_k=self.head_dim_k,
             dtype=self.dtype,
@@ -163,9 +163,9 @@ class InferenceParams:
             assert page_size is not None, "Paged KV cache requires page_size is not None."
             self.page_size = page_size
             assert (
-                max_seqlen_kv % page_size == 0
-            ), "Paged KV cache requires max_seqlen_kv % page_size = 0."
-            max_pages_per_seq = max_seqlen_kv // page_size
+                max_sequence_length % page_size == 0
+            ), "Paged KV cache requires max_sequence_length % page_size = 0."
+            max_pages_per_seq = max_sequence_length // page_size
             assert (
                 total_num_pages == self.max_batch_size * max_pages_per_seq
             ), "Paged KV cache requires total_num_pages = max_batch_size * max_pages_per_seq."
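For illustration, the paging arithmetic these assertions enforce, with hypothetical sizes:

    max_batch_size = 4
    max_sequence_length = 1024
    page_size = 128

    assert max_sequence_length % page_size == 0
    max_pages_per_seq = max_sequence_length // page_size  # 8 pages per sequence
    total_num_pages = max_batch_size * max_pages_per_seq  # 32 pages must be provisioned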
@@ -181,7 +181,7 @@ class InferenceParams:
                 head_dim_k=self.head_dim_k,
                 dtype=self.dtype,
                 max_batch_size=self.max_batch_size,
-                max_seqlen=self.max_seqlen_kv,
+                max_seqlen=self.max_sequence_length,
                 head_dim_v=self.head_dim_v,
             )
@@ -231,7 +231,7 @@ class InferenceParams:
             f"dtype={self.dtype}, "
             f"is_paged={self.is_paged}, "
             f"max_batch_size={self.max_batch_size}, "
-            f"max_seqlen={self.max_seqlen_kv}, "
+            f"max_seqlen={self.max_sequence_length}, "
             f"num_heads={self.num_heads_kv}, "
             f"head_dim_k={self.head_dim_k}, "
             f"head_dim_v={self.head_dim_v}"
@@ -241,8 +241,8 @@ class InferenceParams:
         """
         Allocate memory for the cache. For layer layer_number,
         - NonPagedKVCacheManager:
-          - K cache: [max_batch_size, max_seqlen_kv, num_heads_kv, head_dim_k]
-          - V cache: [max_batch_size, max_seqlen_kv, num_heads_kv, head_dim_v]
+          - K cache: [max_batch_size, max_sequence_length, num_heads_kv, head_dim_k]
+          - V cache: [max_batch_size, max_sequence_length, num_heads_kv, head_dim_v]
         - PagedKVCacheManager:
           - K cache: [total_num_pages, page_size, num_heads_kv, head_dim_k]
           - V cache: [total_num_pages, page_size, num_heads_kv, head_dim_v]
@@ -348,7 +348,7 @@ class InferenceParams:
             Updated cumulative sequence lengths for key and value, [batch_size + 1]
         max_seqlen_q: int
             Update maximum sequence length for query
-        max_seqlen_kv: int
+        max_sequence_length: int
             Update maximum sequence length for key and value
         qkv_format: str
             Updated qkv_format, e.g. 'thd' format becomes 'thd_2bshd' after step()
@@ -373,7 +373,7 @@ class InferenceParams:
             v_cache,
             self.cu_seqlens_q,
             self.cu_seqlens_kv,
-            self.max_seqlen_kv,
+            self.max_sequence_length,
             self.output_qkv_format,
         )
......
@@ -4,6 +4,7 @@
 """NVFuser functions and JIT utilities"""
 import os
+from functools import wraps
 from typing import Callable, Optional, Tuple

 import torch
@@ -11,15 +12,34 @@ from torch.utils.cpp_extension import IS_HIP_EXTENSION
 # pylint: disable=unnecessary-lambda-assignment

+def lazy_compile(func):
+    """Lazy compile a function with torch.compile
+
+    This decorator defers the compilation of a function until the first call, speeding up the
+    overall module's import time if these functions are not used.
+    """
+    compiled_func = None
+
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        nonlocal compiled_func
+        if compiled_func is None:
+            compiled_func = torch.compile(func)
+        return compiled_func(*args, **kwargs)
+
+    return wrapper
+
+
 jit_fuser = lambda func: func
 if torch.__version__ >= "2" and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
-    jit_fuser = torch.compile
+    jit_fuser = lazy_compile

 # See: https://github.com/NVIDIA/TransformerEngine/issues/597
 dropout_fuser = torch.jit.script
 if torch.__version__ >= "2.2" and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
-    dropout_fuser = torch.compile
+    dropout_fuser = lazy_compile

 # Decorator to disable Torch Dynamo
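For illustration, a hedged usage sketch of lazy_compile (the decorated function is hypothetical): the first call pays the torch.compile cost, later calls reuse the compiled artifact, and importing the module stays cheap because nothing is compiled at import time.

    @lazy_compile
    def scaled_tanh(x: torch.Tensor) -> torch.Tensor:
        # arbitrary example kernel for the sketch
        return 2.0 * torch.tanh(0.5 * x)

    y = scaled_tanh(torch.randn(16, 16))  # triggers torch.compile on first use
    y = scaled_tanh(torch.randn(16, 16))  # reuses the compiled function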
......
@@ -1018,6 +1018,13 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         out = None
         if cache_name is not None:
             out = self._fp8_workspaces.get(cache_name, None)
+            if quantizer is not None and isinstance(out, MXFP8TensorBase):
+                if quantizer.rowwise_usage and out._rowwise_data is None:
+                    out = None
+                    del self._fp8_workspaces[cache_name]
+                elif quantizer.columnwise_usage and out._columnwise_data is None:
+                    out = None
+                    del self._fp8_workspaces[cache_name]

         # Gather cached Fp8 workspace if it's distributed
         # NOTE: FSDP sharding is supported only for Fp8 buffers and will not work
......
@@ -78,8 +78,8 @@ class _GroupedLinear(torch.autograd.Function):
         skip_fp8_weight_update,
         *weights_and_biases,
     ) -> torch.Tensor:
         # pylint: disable=missing-function-docstring
         num_gemms = len(m_splits)
         weights = weights_and_biases[:num_gemms]
         biases = weights_and_biases[num_gemms:]
@@ -180,7 +180,12 @@ class _GroupedLinear(torch.autograd.Function):
             ctx.weights_shape_1 = weights[0].shape[1]

-            tensors_to_save, tensor_objects = prepare_for_saving(*inputmats, *weights_fp8, *biases)
+            tensors_to_save, tensor_objects = prepare_for_saving(
+                *inputmats,
+                *weights_fp8,
+                *weights,
+                *biases,
+            )
             ctx.save_for_backward(*tensors_to_save)
             ctx.tensor_objects = tensor_objects
@@ -220,7 +225,8 @@ class _GroupedLinear(torch.autograd.Function):
         N = ctx.num_gemms
         inputmats = saved_tensors[:N]
         weights = saved_tensors[N : 2 * N]
-        biases = saved_tensors[2 * N : 3 * N]
+        origin_weights = saved_tensors[2 * N : 3 * N]
+        biases = saved_tensors[3 * N : 4 * N]
         main_grads = ctx.main_grads

         if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:  # TODO
@@ -311,21 +317,24 @@ class _GroupedLinear(torch.autograd.Function):
             # Deallocate input tensor
             clear_tensor_data(*inputmats)

-            def handle_custom_ddp_from_mcore(w, wgrad):
+            def handle_custom_ddp_from_mcore(weight, wgrad):
                 if ctx.weights_requires_grad:
-                    if ctx.fuse_wgrad_accumulation and hasattr(w, "grad_added_to_main_grad"):
-                        w.grad_added_to_main_grad = True
-                        if getattr(w, "zero_out_wgrad", False):
+                    # Handle custom DDP from mcore.
+                    if ctx.fuse_wgrad_accumulation and hasattr(
+                        weight, "grad_added_to_main_grad"
+                    ):
+                        weight.grad_added_to_main_grad = True
+                        if getattr(weight, "zero_out_wgrad", False):
                             wgrad = torch.zeros(
-                                w.main_grad.shape,
-                                dtype=w.dtype,
+                                weight.main_grad.shape,
+                                dtype=weight.dtype,
                                 device=torch.cuda.current_device(),
                                 requires_grad=False,
                             )
                         else:
                             wgrad = torch.empty(
-                                w.main_grad.shape,
-                                dtype=w.dtype,
+                                weight.main_grad.shape,
+                                dtype=weight.dtype,
                                 device=torch.cuda.current_device(),
                                 requires_grad=False,
                             )
@@ -336,7 +345,8 @@ class _GroupedLinear(torch.autograd.Function):
                 return wgrad

             wgrad_list = [
-                handle_custom_ddp_from_mcore(w, wgrad) for w, wgrad in zip(weights, wgrad_list)
+                handle_custom_ddp_from_mcore(weight, wgrad)
+                for weight, wgrad in zip(origin_weights, wgrad_list)
             ]
         else:
             wgrad_list = [None] * ctx.num_gemms
......
...@@ -55,7 +55,6 @@ from ..tensor.quantized_tensor import ( ...@@ -55,7 +55,6 @@ from ..tensor.quantized_tensor import (
prepare_for_saving, prepare_for_saving,
restore_from_saved, restore_from_saved,
) )
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer
from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param
...@@ -137,6 +136,11 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -137,6 +136,11 @@ class _LayerNormLinear(torch.autograd.Function):
ln_bias = cast_if_needed(ln_bias, activation_dtype) ln_bias = cast_if_needed(ln_bias, activation_dtype)
nvtx_range_pop(f"{nvtx_label}.norm_input_cast") nvtx_range_pop(f"{nvtx_label}.norm_input_cast")
# Avoid quantized norm kernel if norm output will be returned
with_quantized_norm = (
fp8 and not return_layernorm_output and not return_layernorm_output_gathered
)
tp_world_size = get_distributed_world_size(tp_group) tp_world_size = get_distributed_world_size(tp_group)
ub_overlap_ag_fprop = ( ub_overlap_ag_fprop = (
ub_overlap_ag_fprop and is_grad_enabled and not return_layernorm_output ub_overlap_ag_fprop and is_grad_enabled and not return_layernorm_output
...@@ -146,6 +150,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -146,6 +150,7 @@ class _LayerNormLinear(torch.autograd.Function):
backward_needs_input = is_grad_enabled and weight_requires_grad backward_needs_input = is_grad_enabled and weight_requires_grad
with_input_all_gather = parallel_mode == "column" and sequence_parallel with_input_all_gather = parallel_mode == "column" and sequence_parallel
# Check if Userbuffers is supported
if fp8: if fp8:
if any([ub_overlap_ag_fprop, ub_overlap_rs_fprop]) and not ( if any([ub_overlap_ag_fprop, ub_overlap_rs_fprop]) and not (
FP8GlobalStateManager.get_fp8_recipe().float8_per_tensor_scaling() FP8GlobalStateManager.get_fp8_recipe().float8_per_tensor_scaling()
...@@ -155,104 +160,74 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -155,104 +160,74 @@ class _LayerNormLinear(torch.autograd.Function):
" current scaling" " current scaling"
) )
# Configure quantizer for norm output
if fp8:
if input_quantizer is None: if input_quantizer is None:
raise ValueError("Missing quantizer for input tensor") raise ValueError("Missing quantizer for input tensor")
columnwise_usage = backward_needs_input
# Configure quantizer for normalization output if (
with_quantized_norm = fp8 and not return_layernorm_output columnwise_usage
if with_quantized_norm: and with_input_all_gather
if with_input_all_gather: and not isinstance(input_quantizer, MXFP8Quantizer)
input_quantizer.set_usage(rowwise=True, columnwise=False) ):
if isinstance(input_quantizer, MXFP8Quantizer): columnwise_usage = False
with_quantized_norm = False input_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
else:
input_quantizer.set_usage(
rowwise=True,
columnwise=backward_needs_input,
)
# Reduce duplicated transpose in `_fix_gathered_fp8_transpose`
if (
fp8
and FP8GlobalStateManager.get_fp8_recipe().float8_per_tensor_scaling()
and ub_bulk_dgrad
):
input_quantizer.set_usage(rowwise=True, columnwise=False)
ub_obj_fprop = None
ln_out = None
# For DelayScaling, output of normalization will be in fp8.
# For Float8CurrentScaling, we want the output of normalization in high precision, then quantize to fp8.
if ub_overlap_ag_fprop and not isinstance(input_quantizer, Float8CurrentScalingQuantizer):
ub_obj_fprop = get_ub(ub_name + "_fprop")
ln_out = ub_obj_fprop.get_buffer(input_quantizer, local_chunk=True)
elif with_quantized_norm:
if with_input_all_gather:
input_quantizer.set_usage(rowwise=True, columnwise=False)
ln_out = input_quantizer.make_empty(inputmat.shape, dtype=inputmat.dtype, device="cuda")
else:
ln_out = torch.empty_like(
inputmat, dtype=inputmat.dtype, memory_format=torch.contiguous_format, device="cuda"
)
# Apply normalization # Apply normalization
nvtx_range_push(f"{nvtx_label}.norm") nvtx_range_push(f"{nvtx_label}.norm")
ln_out, mu, rsigma = apply_normalization( ln_out, mu, rsigma = apply_normalization(
inputmat, inputmat,
ln_out, None, # ln_out
ln_weight, ln_weight,
ln_bias, ln_bias,
eps, eps,
input_quantizer if with_quantized_norm else None, input_quantizer if with_quantized_norm else None,
inp.dtype, inputmat.dtype,
normalization, normalization,
fwd_ln_sm_margin, fwd_ln_sm_margin,
zero_centered_gamma, zero_centered_gamma,
) )
ln_out_return = ln_out if return_layernorm_output else None ln_out_return = None
if return_layernorm_output or return_layernorm_output_gathered:
ln_out_return = ln_out
nvtx_range_pop(f"{nvtx_label}.norm") nvtx_range_pop(f"{nvtx_label}.norm")
# For Float8CurrentScalingQuantizer, layernorm/rmsnorm has not been fused with quantizer.
# So the output of normalization is in high precision, and we need to quantize it to FP8 and put in the buffer.
if ub_overlap_ag_fprop and isinstance(input_quantizer, Float8CurrentScalingQuantizer):
ub_obj_fprop = get_ub(ub_name + "_fprop")
ln_out_local = ln_out
ln_out = ub_obj_fprop.get_buffer(input_quantizer, local_chunk=True)
input_quantizer.quantize(ln_out_local, out=ln_out)
# Prepare GEMM input # Prepare GEMM input
# Note: Cast to expected dtype and perform tensor-parallel communication # Note: Cast to expected dtype and perform tensor-parallel communication
nvtx_range_push(f"{nvtx_label}.gemm_input_cast_comm") nvtx_range_push(f"{nvtx_label}.gemm_input_cast_comm")
if with_input_all_gather and not ub_overlap_ag_fprop: ln_out_total = None
with_quantized_all_gather = fp8 ub_obj_fprop = None
if return_layernorm_output and return_layernorm_output_gathered: if with_input_all_gather:
with_quantized_all_gather = False if return_layernorm_output_gathered:
if fp8: # Perform all-gather in high precision if gathered
input_quantizer.set_usage(rowwise=True, columnwise=False) # norm output will be returned
# ln_out in this has two possibilities: ln_out_total, _ = gather_along_first_dim(ln_out, tp_group)
# 1. in FP8 low precision, the cast was done by fusing quantization into layernorm kernel
# 2. in high precision, then we need to cast it and then gather in FP8
# the output ln_out_total will be in FP8, and it's a full tensor
ln_out_total, _ = gather_along_first_dim(
ln_out,
tp_group,
quantizer=(input_quantizer if with_quantized_all_gather else None),
)
if return_layernorm_output and return_layernorm_output_gathered:
ln_out_return = ln_out_total ln_out_return = ln_out_total
if fp8 and not with_quantized_all_gather: if fp8:
ln_out_total = input_quantizer(ln_out_total) ln_out = input_quantizer(ln_out)
else: input_quantizer.set_usage(rowwise=True, columnwise=False)
if ub_overlap_ag_fprop: ln_out_total = input_quantizer(ln_out_total)
ln_out_total = ub_obj_fprop.get_buffer(input_quantizer)
else: else:
if fp8: if fp8:
if not isinstance(ln_out, QuantizedTensor): if not with_quantized_norm:
input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input)
ln_out = input_quantizer(ln_out) ln_out = input_quantizer(ln_out)
elif backward_needs_input: input_quantizer.set_usage(rowwise=True, columnwise=False)
ln_out.update_usage(rowwise_usage=True, columnwise_usage=True) if ub_overlap_ag_fprop:
ln_out_total = ln_out # Copy into Userbuffers buffer
ub_obj_fprop = get_ub(ub_name + "_fprop")
ub_obj_fprop.get_buffer(input_quantizer, local_chunk=True).copy_(ln_out)
ln_out_total = ub_obj_fprop.get_buffer(input_quantizer)
else:
# All-gather with NCCL
ln_out_total, _ = gather_along_first_dim(
ln_out,
tp_group,
quantizer=(input_quantizer if fp8 else None),
)
else:
if fp8 and not with_quantized_norm:
ln_out = input_quantizer(ln_out)
ln_out_total = ln_out
nvtx_range_pop(f"{nvtx_label}.gemm_input_cast_comm") nvtx_range_pop(f"{nvtx_label}.gemm_input_cast_comm")
# Cast weight to expected dtype # Cast weight to expected dtype
...@@ -341,7 +316,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -341,7 +316,7 @@ class _LayerNormLinear(torch.autograd.Function):
weight.requires_grad and parallel_mode == "column" and sequence_parallel weight.requires_grad and parallel_mode == "column" and sequence_parallel
) )
# Input with column-wise usage is needed for dgrad GEMM. # Input with column-wise usage is needed for wgrad GEMM.
if backward_needs_input: if backward_needs_input:
if isinstance(ln_out, QuantizedTensor): if isinstance(ln_out, QuantizedTensor):
# For sequence parallel in vanilla FP8, rowwise data is # For sequence parallel in vanilla FP8, rowwise data is
...@@ -350,6 +325,11 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -350,6 +325,11 @@ class _LayerNormLinear(torch.autograd.Function):
if isinstance(ln_out, MXFP8TensorBase) or not ctx.ln_out_needs_gather: if isinstance(ln_out, MXFP8TensorBase) or not ctx.ln_out_needs_gather:
ln_out.update_usage(rowwise_usage=False) ln_out.update_usage(rowwise_usage=False)
# Weight with column-wise usage is needed for dgrad GEMM.
if inp.requires_grad:
if isinstance(weightmat, QuantizedTensor):
weightmat.update_usage(columnwise_usage=True)
if cpu_offloading: if cpu_offloading:
if fp8 and weightmat is not None: if fp8 and weightmat is not None:
set_offloading_param(weightmat, "weight_offloading", True) set_offloading_param(weightmat, "weight_offloading", True)
...@@ -392,7 +372,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -392,7 +372,7 @@ class _LayerNormLinear(torch.autograd.Function):
weight, weight,
bias, bias,
ln_weight, ln_weight,
ln_out.clone() if ub_overlap_ag_fprop else ln_out, # avoid saving a UB buffer ln_out,
mu, mu,
rsigma, rsigma,
) )
...@@ -603,7 +583,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -603,7 +583,7 @@ class _LayerNormLinear(torch.autograd.Function):
quantizer = None quantizer = None
if ctx.fp8: if ctx.fp8:
quantizer = ctx.input_quantizer quantizer = ctx.input_quantizer
quantizer.set_usage(rowwise=True, columnwise=True) quantizer.set_usage(rowwise=False, columnwise=True)
nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input") nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input")
ln_out_total, ln_out_total_work = gather_along_first_dim( ln_out_total, ln_out_total_work = gather_along_first_dim(
ln_out, ln_out,
...@@ -1436,9 +1416,6 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1436,9 +1416,6 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.quantizers["scaling_fwd"][ self.quantizers["scaling_fwd"][
tex.FP8FwdTensors.GEMM1_INPUT tex.FP8FwdTensors.GEMM1_INPUT
].amax_reduction_group = self.tp_group ].amax_reduction_group = self.tp_group
self.quantizers["scaling_fwd"][
tex.FP8FwdTensors.GEMM1_INPUT
].amax_reduction_size = self.tp_size
else: else:
# set grad_output_quantizer with amax epsilon and power_2_scale (no amax reduction here) # set grad_output_quantizer with amax epsilon and power_2_scale (no amax reduction here)
self.quantizers["scaling_bwd"][ self.quantizers["scaling_bwd"][
......
...@@ -61,7 +61,6 @@ from ..tensor.float8_tensor import Float8Tensor ...@@ -61,7 +61,6 @@ from ..tensor.float8_tensor import Float8Tensor
from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ._common import apply_normalization, _fix_gathered_fp8_transpose from ._common import apply_normalization, _fix_gathered_fp8_transpose
from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer
from ..tensor.quantized_tensor import ( from ..tensor.quantized_tensor import (
QuantizedTensor, QuantizedTensor,
Quantizer, Quantizer,
...@@ -208,112 +207,81 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -208,112 +207,81 @@ class _LayerNormMLP(torch.autograd.Function):
if ln_bias is not None: if ln_bias is not None:
ln_bias = cast_if_needed(ln_bias, activation_dtype) ln_bias = cast_if_needed(ln_bias, activation_dtype)
# for fp8 DelayedScaling: layernorm output = FP8 # Avoid quantized norm kernel if norm output will be returned
# only output of the linear is returned with_quantized_norm = (
# for return_layernorm_output: layernorm output = High precision, then cast to FP8 fp8 and not return_layernorm_output and not return_layernorm_output_gathered
# high precision layernorm output and output of the linear are returned )
with_quantized_norm = fp8 and not return_layernorm_output
tp_world_size = get_distributed_world_size(tp_group) tp_world_size = get_distributed_world_size(tp_group)
ub_overlap_ag = ub_overlap_ag and is_grad_enabled and not return_layernorm_output ub_overlap_ag = ub_overlap_ag and is_grad_enabled and not return_layernorm_output_gathered
ub_overlap_rs = ub_overlap_rs and is_grad_enabled ub_overlap_rs = ub_overlap_rs and is_grad_enabled
with_input_all_gather_nccl = sequence_parallel and not ub_overlap_ag
backwards_needs_fc1_input = is_grad_enabled and fc1_weight.requires_grad backwards_needs_fc1_input = is_grad_enabled and fc1_weight.requires_grad
# Configure quantizer for normalization output # Configure quantizer for norm output
if fp8 and fc1_input_quantizer is None: if fp8:
raise ValueError("Missing quantizer for input tensor") if fc1_input_quantizer is None:
if with_quantized_norm: raise ValueError("Missing quantizer for FC1 input tensor")
if with_input_all_gather_nccl: columnwise_usage = backwards_needs_fc1_input
fc1_input_quantizer.set_usage(rowwise=True, columnwise=False) if (
if isinstance(fc1_input_quantizer, MXFP8Quantizer): columnwise_usage
with_quantized_norm = False and sequence_parallel
else: and not isinstance(fc1_input_quantizer, MXFP8Quantizer)
fc1_input_quantizer.set_usage( ):
rowwise=True, columnwise_usage = False
columnwise=backwards_needs_fc1_input, fc1_input_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
)
# Reduce duplicated transpose in `_fix_gathered_fp8_transpose`
if (
fp8
and FP8GlobalStateManager.get_fp8_recipe().float8_per_tensor_scaling()
and ub_bulk_dgrad
):
fc1_input_quantizer.set_usage(rowwise=True, columnwise=False)
ub_obj_lnout = None
ln_out = None
# For DelayScaling, output of normalization will be in fp8.
# For Float8CurrentScaling, we want the output of normalization in high precision, then quantize to fp8.
if ub_overlap_ag and not isinstance(fc1_input_quantizer, Float8CurrentScalingQuantizer):
ub_obj_lnout = get_ub("fc1_fprop")
ln_out = ub_obj_lnout.get_buffer(fc1_input_quantizer, local_chunk=True)
elif not with_quantized_norm:
ln_out = torch.empty_like(
inputmat, dtype=inputmat.dtype, memory_format=torch.contiguous_format, device="cuda"
)
# Apply normalization # Apply normalization
ln_out, mu, rsigma = apply_normalization( ln_out, mu, rsigma = apply_normalization(
inputmat, inputmat,
ln_out, None, # ln_out
ln_weight, ln_weight,
ln_bias, ln_bias,
eps, eps,
fc1_input_quantizer if with_quantized_norm else None, fc1_input_quantizer if with_quantized_norm else None,
inp.dtype, inputmat.dtype,
normalization, normalization,
fwd_ln_sm_margin, fwd_ln_sm_margin,
zero_centered_gamma, zero_centered_gamma,
) )
ln_out_return = None
ln_out_return = ln_out if return_layernorm_output else None if return_layernorm_output or return_layernorm_output_gathered:
ln_out_return = ln_out
# For Float8CurrentScalingQuantizer, layernorm/rmsnorm has not been fused with quantizer.
# So the output of normalization is in high precision, and we need to quantize it to FP8 and put in the buffer.
if ub_overlap_ag and isinstance(fc1_input_quantizer, Float8CurrentScalingQuantizer):
ub_obj_lnout = get_ub("fc1_fprop")
ln_out_local = ln_out
ln_out = ub_obj_lnout.get_buffer(fc1_input_quantizer, local_chunk=True)
fc1_input_quantizer.quantize(ln_out_local, out=ln_out)
# Prepare GEMM input # Prepare GEMM input
# Note: Cast to expected dtype and perform tensor-parallel communication # Note: Cast to expected dtype and perform tensor-parallel communication
ln_out_gathered = False ln_out_total = None
with_quantized_all_gather = fp8 ub_obj_lnout = None
if with_input_all_gather_nccl: if sequence_parallel:
if return_layernorm_output and return_layernorm_output_gathered: if return_layernorm_output_gathered:
with_quantized_all_gather = False # Perform all-gather in high precision if gathered
if fp8: # norm output will be returned
fc1_input_quantizer.set_usage(rowwise=True, columnwise=False) ln_out_total, _ = gather_along_first_dim(ln_out, tp_group)
# ln_out in this has two possibilities: ln_out_return = ln_out_total
# 1. in FP8 low precision, the cast was done by fusing quantization into layernorm kernel if fp8:
# 2. in high precision, then we need to cast it and then gather in FP8 ln_out = fc1_input_quantizer(ln_out)
# the output ln_out_total will be in FP8, and it's a full tensor fc1_input_quantizer.set_usage(rowwise=True, columnwise=False)
ln_out_total, _ = gather_along_first_dim( ln_out_total = fc1_input_quantizer(ln_out_total)
ln_out,
tp_group,
quantizer=(fc1_input_quantizer if with_quantized_all_gather else None),
)
ln_out_gathered = True
else:
with_quantized_all_gather = False
if ub_overlap_ag:
ln_out_total = ub_obj_lnout.get_buffer(fc1_input_quantizer, False)
else: else:
if fp8: if fp8:
if not isinstance(ln_out, QuantizedTensor): if not with_quantized_norm:
fc1_input_quantizer.set_usage(
rowwise=True, columnwise=backwards_needs_fc1_input
)
ln_out = fc1_input_quantizer(ln_out) ln_out = fc1_input_quantizer(ln_out)
elif backwards_needs_fc1_input: fc1_input_quantizer.set_usage(rowwise=True, columnwise=False)
ln_out.update_usage(rowwise_usage=True, columnwise_usage=True) if ub_overlap_ag:
# here ln_out is in FP8 low precision, the cast was either done by fc1_input_quantizer # Copy into Userbuffers buffer
# or fused into the layernorm kernel ub_obj_lnout = get_ub("fc1_fprop")
# ln_out_total represents the full fp8 tensor, in this case, it's the same as ln_out ub_obj_lnout.get_buffer(fc1_input_quantizer, local_chunk=True).copy_(ln_out)
ln_out_total = ln_out ln_out_total = ub_obj_lnout.get_buffer(fc1_input_quantizer)
else:
# All-gather with NCCL
ln_out_total, _ = gather_along_first_dim(
ln_out,
tp_group,
quantizer=(fc1_input_quantizer if fp8 else None),
)
else:
if fp8 and not with_quantized_norm:
ln_out = fc1_input_quantizer(ln_out)
ln_out_total = ln_out
# Cast weights to expected dtype # Cast weights to expected dtype
if not fp8: if not fp8:
...@@ -423,7 +391,6 @@ class _LayerNormMLP(torch.autograd.Function):
                dim_size[0] = dim_size[0] // tp_world_size
                dim_size[1] = fc2_weight.size(0)
                rs_out = torch.empty(dim_size, dtype=activation_dtype, device=device)
-                fc2_out = ub_obj_fc2out.get_buffer(output_quantizer)
            else:
                dim_size = list(act_out.size())
                dim_size[1] = fc2_weight.size(0)
...@@ -443,6 +410,14 @@ class _LayerNormMLP(torch.autograd.Function):
                ub_type=tex.CommOverlapType.RS if ub_overlap_rs else None,
                extra_output=rs_out,
            )
+
+            # Weight with column-wise usage is needed for dgrad GEMM.
+            if is_grad_enabled and inp.requires_grad:
+                if isinstance(fc1_weight_final, QuantizedTensor):
+                    fc1_weight_final.update_usage(columnwise_usage=True)
+                if isinstance(fc2_weight_final, QuantizedTensor):
+                    fc2_weight_final.update_usage(columnwise_usage=True)
+
            if not is_grad_enabled:
                clear_tensor_data(act_out, fc1_out_without_bias, fc1_out)
            else:
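The block added in the hunk above marks quantized weights as needing column-wise data because the backward dgrad GEMM consumes the weight in its transposed layout. A plain-PyTorch sketch of the two backward GEMMs, with assumed toy shapes:

import torch

x = torch.randn(8, 16)    # input  [tokens, in_features]
w = torch.randn(32, 16)   # weight [out_features, in_features]; y = x @ w.T
dy = torch.randn(8, 32)   # incoming gradient [tokens, out_features]

dgrad = dy @ w            # dX = dY @ W    -> consumes the weight (transposed/column-wise operand)
wgrad = dy.t() @ x        # dW = dY^T @ X  -> consumes the saved input

Requesting column-wise usage up front lets the FP8 weight keep or build its transposed copy before the row-wise copy is potentially freed.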
...@@ -490,13 +465,15 @@ class _LayerNormMLP(torch.autograd.Function):
            tensors_to_save, tensor_objects = prepare_for_saving(
                inputmat,
                ln_weight,
-                ln_out.clone() if ub_overlap_ag else ln_out,  # avoid saving a UB buffer
+                ln_out,
                fc1_weight_final,
+                fc1_weight,
                fc1_bias,
                fc1_out,
                fc1_out_without_bias,
                act_out,
                fc2_weight_final,
+                fc2_weight,
                fc2_bias,
                mu,
                rsigma,
...@@ -537,7 +514,7 @@ class _LayerNormMLP(torch.autograd.Function):
            ctx.bias_gelu_fusion = bias_gelu_fusion
            ctx.return_layernorm_output = return_layernorm_output
            ctx.return_layernorm_output_gathered = (
-                return_layernorm_output_gathered and ln_out_gathered
+                return_layernorm_output_gathered and sequence_parallel
            )
            ctx.set_parallel_mode = set_parallel_mode
            ctx.bwd_ln_sm_margin = bwd_ln_sm_margin
...@@ -609,11 +586,13 @@ class _LayerNormMLP(torch.autograd.Function):
            ln_weight,
            ln_out,
            fc1_weight,
+            origin_fc1_weight,
            fc1_bias,
            fc1_out,
            fc1_out_without_bias,
            act_out,
            fc2_weight,
+            origin_fc2_weight,
            fc2_bias,
            mu,
            rsigma,
...@@ -632,7 +611,7 @@ class _LayerNormMLP(torch.autograd.Function):
        )
        fc2_weight_main_grad = (
            ctx.fc2_main_grad
-            if fc2_weight is not None
+            if origin_fc2_weight is not None
            and ctx.fuse_wgrad_accumulation
            and ctx.fc2_weight_requires_grad
            else None
...@@ -641,8 +620,8 @@ class _LayerNormMLP(torch.autograd.Function):
        # For CPU offloading, we offloaded weight and weight.main_grad to different tensors,
        # we need to connect them into one.
        if ctx.fuse_wgrad_accumulation:
-            fc1_weight.main_grad = fc1_weight_main_grad
-            fc2_weight.main_grad = fc2_weight_main_grad
+            origin_fc1_weight.main_grad = fc1_weight_main_grad
+            origin_fc2_weight.main_grad = fc2_weight_main_grad
        # TODO: Fix this  # pylint: disable=fixme
        # Gather saved autograd context tensors when running with FSDP
...@@ -697,7 +676,7 @@ class _LayerNormMLP(torch.autograd.Function):
            quantizer = None
            if ctx.fp8:
                quantizer = ctx.fc1_input_quantizer
-                quantizer.set_usage(rowwise=True, columnwise=True)
+                quantizer.set_usage(rowwise=False, columnwise=True)
            ln_out_total, ln_out_total_work = gather_along_first_dim(
                ln_out,
                ctx.tp_group,
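In the hunk above, the backward all-gather of ln_out switches the quantizer to column-wise only: the regathered tensor feeds only the wgrad GEMM, so the row-wise copy would be dead weight. A rough byte count for one shard, assuming both copies store one byte per element and ignoring scale factors:

# Hypothetical shard of 4096 x 8192 FP8 values
rows, cols = 4096, 8192
both_copies = 2 * rows * cols   # rowwise + columnwise payload
columnwise_only = rows * cols   # what the backward gather now moves
assert columnwise_only == both_copies // 2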
...@@ -759,14 +738,18 @@ class _LayerNormMLP(torch.autograd.Function):
                act_out,
                grad_output,
                get_workspace(),
-                out_dtype=ctx.activation_dtype,
+                out_dtype=(
+                    origin_fc2_weight.main_grad.dtype
+                    if ctx.fuse_wgrad_accumulation
+                    else ctx.activation_dtype
+                ),
                quantization_params=None,  # wgrad in high precision
                layout="NT",
                grad=True,
                bias=fc2_bias if fc2_bias_grad is None else None,
                accumulate=accumulate_wgrad_into_param_main_grad,
                use_split_accumulator=_2X_ACC_WGRAD,
-                out=fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
+                out=origin_fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
            )
            if fc2_bias_grad is None:
                fc2_bias_grad = fc2_bias_grad_
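The out_dtype change above matters when fuse_wgrad_accumulation is enabled: the wgrad GEMM writes directly into weight.main_grad, so its output dtype has to follow the accumulator rather than the activation dtype. A minimal illustration with assumed dtypes:

import torch

main_grad = torch.zeros(32, 16, dtype=torch.float32)  # FP32 accumulator owned by the optimizer/DDP
activation_dtype = torch.bfloat16
fuse_wgrad_accumulation = True

out_dtype = main_grad.dtype if fuse_wgrad_accumulation else activation_dtype
assert out_dtype == torch.float32  # accumulate in the accumulator's precision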
...@@ -919,12 +902,16 @@ class _LayerNormMLP(torch.autograd.Function):
                ln_out_total,
                dact,
                get_workspace(),
-                out_dtype=ctx.activation_dtype,
+                out_dtype=(
+                    origin_fc1_weight.main_grad.dtype
+                    if ctx.fuse_wgrad_accumulation
+                    else ctx.activation_dtype
+                ),
                layout="NT",
                grad=fuse_gemm_and_bias_fc1_wgrad,
                bias=fc1_bias if fuse_gemm_and_bias_fc1_wgrad else None,
                accumulate=accumulate_wgrad_into_param_main_grad,
-                out=fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
+                out=origin_fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
                ub=ub_obj_fc1_wgrad,
                ub_type=tex.CommOverlapType.RS if ctx.ub_bulk_wgrad else None,
                extra_output=fc1_dgrad_rs_out,
...@@ -985,16 +972,21 @@ class _LayerNormMLP(torch.autograd.Function):
        if ctx.fc1_weight_requires_grad:
            # Handle custom DDP from mcore.
            if ctx.fuse_wgrad_accumulation and hasattr(fc1_weight, "grad_added_to_main_grad"):
-                fc1_weight.grad_added_to_main_grad = True
-                if getattr(fc1_weight, "zero_out_wgrad", False):
+                origin_fc1_weight.grad_added_to_main_grad = True
+                if getattr(origin_fc1_weight, "zero_out_wgrad", False):
                    fc1_wgrad = torch.zeros(
-                        fc1_weight.main_grad.shape,
-                        dtype=fc1_weight.dtype,
+                        origin_fc1_weight.main_grad.shape,
+                        dtype=origin_fc1_weight.dtype,
                        device=torch.cuda.current_device(),
                        requires_grad=False,
                    )
                else:
-                    fc1_wgrad = None
+                    fc1_wgrad = torch.empty(
+                        origin_fc1_weight.main_grad.shape,
+                        dtype=origin_fc1_weight.dtype,
+                        device=torch.cuda.current_device(),
+                        requires_grad=False,
+                    )
            elif ctx.fuse_wgrad_accumulation:
                fc1_wgrad = None
            else:
...@@ -1002,17 +994,24 @@ class _LayerNormMLP(torch.autograd.Function):
        if ctx.fc2_weight_requires_grad:
            # Handle custom DDP from mcore.
-            if ctx.fuse_wgrad_accumulation and hasattr(fc2_weight, "grad_added_to_main_grad"):
-                fc2_weight.grad_added_to_main_grad = True
-                if getattr(fc2_weight, "zero_out_wgrad", False):
+            if ctx.fuse_wgrad_accumulation and hasattr(
+                origin_fc2_weight, "grad_added_to_main_grad"
+            ):
+                origin_fc2_weight.grad_added_to_main_grad = True
+                if getattr(origin_fc2_weight, "zero_out_wgrad", False):
                    fc2_wgrad = torch.zeros(
-                        fc2_weight.main_grad.shape,
-                        dtype=fc2_weight.dtype,
+                        origin_fc2_weight.main_grad.shape,
+                        dtype=origin_fc2_weight.dtype,
                        device=torch.cuda.current_device(),
                        requires_grad=False,
                    )
                else:
-                    fc2_wgrad = None
+                    fc2_wgrad = torch.empty(
+                        origin_fc2_weight.main_grad.shape,
+                        dtype=origin_fc2_weight.dtype,
+                        device=torch.cuda.current_device(),
+                        requires_grad=False,
+                    )
            elif ctx.fuse_wgrad_accumulation:
                fc2_wgrad = None
            else:
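The else-branches above now return a placeholder tensor instead of None when Megatron-core style DDP has already folded the real gradient into weight.main_grad. A sketch of that convention, assuming a parameter that carries a main_grad attribute (hypothetical helper, not TE code):

import torch

def placeholder_wgrad(param: torch.nn.Parameter, zero_out: bool) -> torch.Tensor:
    # The true gradient already lives in param.main_grad; autograd still expects
    # a tensor of the parameter's shape, so hand back a dummy one.
    factory = torch.zeros if zero_out else torch.empty
    return factory(
        param.main_grad.shape, dtype=param.dtype, device=param.device, requires_grad=False
    )

Returning torch.empty instead of None keeps downstream code that unconditionally inspects the wgrad from having to special-case the fused path.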
...@@ -1602,9 +1601,6 @@ class LayerNormMLP(TransformerEngineBaseModule):
                self.quantizers["scaling_fwd"][
                    tex.FP8FwdTensors.GEMM1_INPUT
                ].amax_reduction_group = self.tp_group
-                self.quantizers["scaling_fwd"][
-                    tex.FP8FwdTensors.GEMM1_INPUT
-                ].amax_reduction_size = self.tp_size
        else:
            # grad_fc2_output_quantizer: set configs about amax epsilon and power_2_scale for grad_fc2_output_quantizer
            self.quantizers["scaling_bwd"][
...@@ -1628,6 +1624,3 @@ class LayerNormMLP(TransformerEngineBaseModule):
                self.quantizers["scaling_bwd"][
                    tex.FP8BwdTensors.GRAD_OUTPUT1
                ].amax_reduction_group = self.tp_group
-                self.quantizers["scaling_bwd"][
-                    tex.FP8BwdTensors.GRAD_OUTPUT1
-                ].amax_reduction_size = self.tp_size
...@@ -56,6 +56,7 @@ from ..tensor.quantized_tensor import (
    prepare_for_saving,
    restore_from_saved,
)
+from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param
...@@ -140,9 +141,13 @@ class _Linear(torch.autograd.Function):
            if input_quantizer is None:
                raise ValueError("Missing quantizer for input tensor")
            if with_input_all_gather_nccl:
-                assert not isinstance(
-                    inputmat, QuantizedTensor
-                ), "All gather of fp8 input is not supported"
+                if not isinstance(inputmat, QuantizedTensor):
+                    columnwise_usage = backward_needs_input and isinstance(
+                        input_quantizer, MXFP8Quantizer
+                    )
+                    input_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
+                    inputmat = input_quantizer(inputmat)
+                    own_quantized_input = True
                input_quantizer.set_usage(rowwise=True, columnwise=False)
                inputmat_total, _ = gather_along_first_dim(
                    inputmat,
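The rewritten branch above quantizes the input before the tensor-parallel all-gather instead of rejecting quantized inputs, and requests the column-wise copy only when an MXFP8 quantizer is in use and backward will reuse the gathered input. A toy version of that decision, using stand-in classes rather than the TE quantizers:

class MXFP8Quantizer: ...      # stand-in for a block-scaled (MXFP8) quantizer
class Float8Quantizer: ...     # stand-in for a per-tensor FP8 quantizer

def columnwise_usage(quantizer, backward_needs_input: bool) -> bool:
    # Only MXFP8 all-gathers column-wise data directly, so only then is it
    # worth producing the columnwise copy before the gather.
    return backward_needs_input and isinstance(quantizer, MXFP8Quantizer)

assert columnwise_usage(MXFP8Quantizer(), True) is True
assert columnwise_usage(Float8Quantizer(), True) is False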
...@@ -269,9 +274,14 @@ class _Linear(torch.autograd.Function):
                # to gather the input. For MXFP8, columnwise only data
                # can be allgathered.
                if isinstance(inputmat, MXFP8TensorBase) or not ctx.backward_input_needs_gather:
-                    inputmat.update_usage(rowwise_usage=False)
+                    inputmat.update_usage(rowwise_usage=False, columnwise_usage=True)
                saved_inputmat = inputmat
+
+            # Weight with column-wise usage is needed for dgrad GEMM.
+            if inp.requires_grad:
+                if isinstance(weightmat, QuantizedTensor):
+                    weightmat.update_usage(columnwise_usage=True)
            if cpu_offloading:
                set_offloading_param(weight, "weight_offloading", True)
                set_offloading_param(weightmat, "weight_offloading", True)
...@@ -489,7 +499,7 @@ class _Linear(torch.autograd.Function):
            quantizer = None
            if ctx.fp8:
                quantizer = ctx.input_quantizer
-                quantizer.set_usage(rowwise=True, columnwise=True)
+                quantizer.set_usage(rowwise=False, columnwise=True)
            nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input")
            inputmat_total, inputmat_total_work = gather_along_first_dim(
                inputmat,
...@@ -1211,9 +1221,6 @@ class Linear(TransformerEngineBaseModule):
                self.quantizers["scaling_fwd"][
                    tex.FP8FwdTensors.GEMM1_INPUT
                ].amax_reduction_group = self.tp_group
-                self.quantizers["scaling_fwd"][
-                    tex.FP8FwdTensors.GEMM1_INPUT
-                ].amax_reduction_size = self.tp_size
        else:
            # set grad_output_quantizer with amax epsilon and power_2_scale
            self.quantizers["scaling_bwd"][
...@@ -1231,6 +1238,3 @@ class Linear(TransformerEngineBaseModule):
                self.quantizers["scaling_bwd"][
                    tex.FP8BwdTensors.GRAD_OUTPUT1
                ].amax_reduction_group = self.tp_group
-                self.quantizers["scaling_bwd"][
-                    tex.FP8BwdTensors.GRAD_OUTPUT1
-                ].amax_reduction_size = self.tp_size
...@@ -283,7 +283,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
            recipe_state = fp8_meta[fp8_meta_key]
            # Reallocate amax history if needed
-            if recipe.mxfp8():
+            if not recipe.delayed():
                continue
            current_length = recipe_state.amax_history.size(0)
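The guard above broadens the skip from MXFP8 recipes to every non-delayed recipe: an amax history only exists under delayed scaling, so current scaling and block scaling have nothing to reallocate. A toy illustration of the control flow, using mock recipe objects rather than the TE recipe classes:

class ToyRecipe:
    def __init__(self, kind: str):
        self.kind = kind
    def delayed(self) -> bool:
        return self.kind == "delayed"

for r in (ToyRecipe("delayed"), ToyRecipe("mxfp8"), ToyRecipe("current")):
    if not r.delayed():
        continue  # no amax history to resize for non-delayed recipes
    # ... reallocate recipe_state.amax_history here ...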
...
...@@ -11,7 +11,7 @@ import torch
import transformer_engine_torch as tex
from transformer_engine_torch import DType as TE_DType
-from ..utils import devices_match, non_tn_fp8_gemm_supported
+from ..utils import canonicalize_process_group, devices_match, non_tn_fp8_gemm_supported
from ._internal.float8_tensor_base import Float8TensorBase, _FromFloat8Func
from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc
from ..constants import dist_group_type
...@@ -194,7 +194,6 @@ class Float8CurrentScalingQuantizer(Quantizer):
    """amax reduction options"""
    with_amax_reduction: bool
    amax_reduction_group: Optional[dist_group_type]
-    amax_reduction_size: Optional[int]
    """Options about how to quantize the tensor"""
    force_pow_2_scales: bool
    amax_epsilon: float
...@@ -208,7 +207,6 @@ class Float8CurrentScalingQuantizer(Quantizer):
        columnwise: bool = True,
        with_amax_reduction: bool = False,
        amax_reduction_group: Optional[dist_group_type] = None,
-        amax_reduction_size: Optional[int] = 1,
        force_pow_2_scales: bool = False,
        amax_epsilon: float = 0.0,
    ) -> None:
...@@ -218,7 +216,6 @@ class Float8CurrentScalingQuantizer(Quantizer):
        self.dtype = fp8_dtype
        self.with_amax_reduction = with_amax_reduction
        self.amax_reduction_group = amax_reduction_group
-        self.amax_reduction_size = amax_reduction_size
        self.force_pow_2_scales = force_pow_2_scales
        self.amax_epsilon = amax_epsilon
...@@ -327,6 +324,10 @@ class Float8CurrentScalingQuantizer(Quantizer):
            quantizer=self,
        )
+
+    def _canonicalized_amax_reduction_group(self) -> dist_group_type:
+        """Get process group for amax reduction"""
+        return canonicalize_process_group(self.amax_reduction_group)
class Float8Tensor(Float8TensorBase, QuantizedTensor):
    """Experimental tensor class with FP8 data
...
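The new _canonicalized_amax_reduction_group helper above presumably normalizes the optional amax_reduction_group before the amax all-reduce. A guess at what such a canonicalization does, sketched with plain torch.distributed (an assumption, not the actual ..utils implementation):

import torch.distributed as dist

def toy_canonicalize_process_group(group):
    # Fall back to the default (WORLD) group when no group was supplied.
    return dist.group.WORLD if group is None else group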