# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """Quantization utilities for TransformerEngine""" from __future__ import annotations import abc import itertools import functools import warnings import os from contextlib import contextmanager from collections import deque from typing import Callable, List, Optional, Dict, Any, Tuple, Union import torch import transformer_engine_torch as tex from transformer_engine.common.recipe import ( Recipe, DelayedScaling, Format, MXFP8BlockScaling, Float8CurrentScaling, Float8BlockScaling, NVFP4BlockScaling, CustomRecipe, ) from .constants import dist_group_type from .utils import get_device_compute_capability from .jit import jit_fuser __all__ = [ "autocast", "quantized_model_init", "is_fp8_available", "is_mxfp8_available", "is_fp8_block_scaling_available", "is_nvfp4_available", "get_default_recipe", ] @functools.lru_cache(maxsize=None) def check_fp8_support() -> Tuple[bool, str]: """Return if fp8 support is available""" if get_device_compute_capability() >= (9, 0): # hopper and above return True, "" if get_device_compute_capability() < (8, 9): # pre-ada return False, "Device compute capability 8.9 or higher required for FP8 execution." if tex.get_cublasLt_version() < 120103: return False, "CublasLt version 12.1.3.x or higher required for FP8 execution on Ada." if float(torch.version.cuda) < 12.1: return False, "Cuda version 12.1 or higher required for FP8 execution on Ada." return True, "" @functools.lru_cache(maxsize=None) def check_mxfp8_support() -> Tuple[bool, str]: """Return if fp8 support is available""" if get_device_compute_capability() >= (12, 0): return False, "MXFP8 (for all gemm layouts) is not supported on 12.0+ architectures yet." if get_device_compute_capability() >= (10, 0): # blackwell and above return True, "" return False, "Device compute capability 10.0 or higher required for MXFP8 execution." @functools.lru_cache(maxsize=None) def check_nvfp4_support() -> Tuple[bool, str]: """Return if nvfp4 support is available""" if get_device_compute_capability() >= (10, 0): # blackwell and above return True, "" return False, "Device compute capability 10.0 or higher required for NVFP4 execution." @functools.lru_cache(maxsize=None) def check_fp8_block_scaling_support() -> Tuple[bool, str]: """Return if fp8 block scaling support is available""" if get_device_compute_capability() >= (9, 0) and float(torch.version.cuda) >= 12.9: return True, "" return ( False, "FP8 block scaled GEMM requires compute capability 9.0 or higher and CUDA >= 12.9.", ) def check_recipe_support(recipe: Recipe) -> None: """Check if the given recipe is supported.""" recipe_supported = True unsupported_reason = "" if isinstance(recipe, (DelayedScaling, Float8CurrentScaling)): recipe_supported, unsupported_reason = check_fp8_support() elif isinstance(recipe, Float8BlockScaling): recipe_supported, unsupported_reason = check_fp8_block_scaling_support() elif isinstance(recipe, MXFP8BlockScaling): recipe_supported, unsupported_reason = check_mxfp8_support() assert recipe_supported, unsupported_reason def get_default_fp8_recipe() -> Recipe: """FP8 recipe with default args.""" if check_mxfp8_support()[0]: return MXFP8BlockScaling() if get_device_compute_capability() >= (12, 0): # This is a temporary restriction until MXFP8 is supported for all gemm layouts. return Float8CurrentScaling() return DelayedScaling() def get_default_recipe() -> Recipe: """Returns the default training recipe based on available device.""" return get_default_fp8_recipe() def get_fp8_torch_dtype(fp8_recipe: Recipe, fprop_tensor: bool = True) -> torch.dtype: """Get fp8 data type according to recipe and tensor""" if fp8_recipe.fp8_format == Format.E4M3 or ( fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor ): return torch.float8_e4m3fn return torch.float8_e5m2 def get_fp8_te_dtype(fp8_recipe: Recipe, fprop_tensor: bool = True) -> tex.DType: """Get fp8 data type according to recipe and tensor""" if fp8_recipe.fp8_format == Format.E4M3 or ( fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor ): return tex.DType.kFloat8E4M3 return tex.DType.kFloat8E5M2 def get_fp4_te_dtype(fp4_recipe: Recipe) -> tex.DType: """Get fp4 data type according to recipe and tensor""" if fp4_recipe.fp4_format == Format.E2M1: return tex.DType.kFloat4E2M1 raise ValueError(f"Unsupported FP4 format: {fp4_recipe.fp4_format}") def get_fp8_max(fp8_recipe: Recipe, fprop_tensor: bool = True) -> tex.DType: """Get max representible FP8 value.""" if fp8_recipe.fp8_format == Format.E4M3 or ( fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor ): return Format.E4M3.value.max_fwd return Format.E5M2.value.max_fwd def is_fp8_available(return_reason: bool = False) -> Union[bool, Tuple[bool, str]]: """ Determine if FP8 support is available for the delayed scaling and per tensor current scaling recipe. Parameters ---------- return_reason : bool, optional If ``False`` (default), return only a boolean indicating availability. If ``True``, return a tuple ``(is_available, reason)`` where ``reason`` provides a human-readable explanation when required support is not available. The reason will be an empty string if support for FP8 is available. """ if return_reason: return check_fp8_support() return check_fp8_support()[0] def is_mxfp8_available(return_reason: bool = False) -> Union[bool, Tuple[bool, str]]: """ Determine if support is available for the MXFP8 recipe. Parameters ---------- return_reason : bool, optional If ``False`` (default), return only a boolean indicating availability. If ``True``, return a tuple ``(is_available, reason)`` where ``reason`` provides a human-readable explanation when required support is not available. The reason will be an empty string if support for MXFP8 is available. """ if return_reason: return check_mxfp8_support() return check_mxfp8_support()[0] def is_fp8_block_scaling_available(return_reason: bool = False) -> Union[bool, Tuple[bool, str]]: """ Determine if support is available for the FP8 block scaling recipe. Parameters ---------- return_reason : bool, optional If ``False`` (default), return only a boolean indicating availability. If ``True``, return a tuple ``(is_available, reason)`` where ``reason`` provides a human-readable explanation when required support is not available. The reason will be an empty string if support for FP8 block scaling is available. """ if return_reason: return check_fp8_block_scaling_support() return check_fp8_block_scaling_support()[0] def is_nvfp4_available(return_reason: bool = False) -> Union[bool, Tuple[bool, str]]: """ Determine if support is available for the NVFP4 recipe. Parameters ---------- return_reason : bool, optional If ``False`` (default), return only a boolean indicating availability. If ``True``, return a tuple ``(is_available, reason)`` where ``reason`` provides a human-readable explanation when required support is not available. The reason will be an empty string if support for NVFP4 is available. """ if return_reason: return check_nvfp4_support() return check_nvfp4_support()[0] class FP8GlobalStateManager: """Class to keep track of and manipulate the global FP8 state at different stages of execution. """ FP8_ENABLED = False FP8_CALIBRATION = False FP8_RECIPE = None FP8_DISTRIBUTED_GROUP = None FP8_PARAMETERS = False HIGH_PRECISION_INIT_VAL = False IS_FIRST_FP8_MODULE = False FP8_GRAPH_CAPTURING = False AUTOCAST_DEPTH = 0 global_amax_buffer = {} global_amax_history_buffer = {} global_scale_buffer = {} fp8_tensors_recompute_buffer = [] fp8_available = None reason_for_no_fp8 = "" autocast_arguments = {} skip_fp8_weight_update_tensor = None mxfp8_available = None reason_for_no_mxfp8 = "" fp8_block_scaling_available = None reason_for_no_fp8_block_scaling = None nvfp4_available = None reason_for_no_nvfp4 = "" @classmethod def reset(cls) -> None: """Reset the global state""" cls.FP8_ENABLED = False cls.FP8_CALIBRATION = False cls.FP8_RECIPE = None cls.FP8_DISTRIBUTED_GROUP = None cls.FP8_PARAMETERS = False cls.HIGH_PRECISION_INIT_VAL = False cls.IS_FIRST_FP8_MODULE = False cls.FP8_GRAPH_CAPTURING = False cls.AUTOCAST_DEPTH = 0 cls.global_amax_buffer = {} cls.global_amax_history_buffer = {} cls.global_scale_buffer = {} cls.fp8_tensors_recompute_buffer = [] cls.fp8_available = None cls.reason_for_no_fp8 = "" cls.autocast_arguments = {} cls.skip_fp8_weight_update_tensor = None cls.mxfp8_available = None cls.reason_for_no_mxfp8 = "" cls.fp8_block_scaling_available = None cls.reason_for_no_fp8_block_scaling = "" @classmethod def set_skip_fp8_weight_update_tensor(cls, skip: bool) -> None: """`skip_fp8_weight_update_tensor` inplace setter.""" if cls.skip_fp8_weight_update_tensor is None: cls.skip_fp8_weight_update_tensor = torch.empty(1, dtype=torch.float32, device="cuda") cls.skip_fp8_weight_update_tensor.fill_(skip) @classmethod def get_skip_fp8_weight_update_tensor(cls) -> None: """`skip_fp8_weight_update_tensor` getter.""" return cls.skip_fp8_weight_update_tensor @classmethod def is_fp8_available(cls) -> Tuple[bool, str]: """Return if fp8 support is available""" return check_fp8_support() @classmethod def is_mxfp8_available(cls) -> Tuple[bool, str]: """Return if MXFP8/current scaling support is available.""" return check_mxfp8_support() @classmethod def is_fp8_block_scaling_available(cls) -> Tuple[bool, str]: """Return if Float8 block scaling support is available.""" return check_fp8_block_scaling_support() @classmethod def is_nvfp4_available(cls) -> Tuple[bool, str]: """Return if NVFP4 support is available.""" return check_nvfp4_support() @staticmethod def get_meta_tensor_key(forward: bool = True) -> str: """Returns scaling key in `fp8_meta`.""" if forward: return "scaling_fwd" return "scaling_bwd" @staticmethod def get_fwd_bwd_key(forward: bool = True) -> str: """Convert bool `forward` to string.""" return "forward" if forward else "backward" @classmethod def get_buffer_info(cls) -> str: """ Returns a key for `fp8_meta` that stores the module's index in the global buffers along with autocast information. """ return "buffer_index_and_autocast_key" @classmethod def get_key_in_buffer( cls, forward: bool, fp8_recipe: Recipe, fp8_group: dist_group_type, ) -> str: """Returns a key into the global FP8 buffers.""" autocast_key = cls.get_unique_autocast_key(fp8_recipe, fp8_group) fwd_bwd_key = cls.get_fwd_bwd_key(forward) return f"{fwd_bwd_key}_{autocast_key}" @classmethod def split_key_in_buffer(cls, key: str) -> Tuple[bool, str]: """Splits buffer key into relevant parts.""" forward, autocast_key = key.split("_", 1) forward = forward == "forward" return forward, autocast_key @classmethod def add_fp8_tensors_to_global_buffer( cls, fp8_meta: Dict[str, Any], ) -> None: """ Delayed scaling only. The amax reduction process happens completely outside the FP8 modules. To participate in the reduction, the only role played by a module is to call this function in order to append it's FP8 tensor into a global buffer. There are 5 global buffers maintained, one each for amax, amax history, scale, scale-inverse, and non-weight-mask. Each buffer has keys that hold FP8 tensors. Keys have a `forward_` or `backward_` prefix to indicate the type of FP8 tensor, since the forward and backward reductions happen separately. Note: For CG capture, this method is called from the graphed wrapper. For non CG case, it's called from within the module. """ # delayed scaling only function, noop for any other recipe if not fp8_meta["recipe"].delayed(): return # Every module must call this function exactly once since # the amax tensors are static. Ensures that compatibility # with non-graphed modules is maintained. index_in_buffer = cls.get_buffer_info() # Same index for fwd/bwd fp8 tensors. if index_in_buffer in fp8_meta: return fp8_meta[index_in_buffer] = [] for forward in (True, False): fp8_meta_tensor_key = cls.get_meta_tensor_key(forward=forward) if fp8_meta_tensor_key not in fp8_meta: # Handles non-parameter FP8 modules, e.g. DPA. continue key = cls.get_key_in_buffer(forward, fp8_meta["recipe"], fp8_meta["fp8_group"]) if key not in cls.global_amax_buffer: cls.global_amax_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history[0]] cls.global_amax_history_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history] cls.global_scale_buffer[key] = [fp8_meta[fp8_meta_tensor_key].scale] else: cls.global_amax_buffer[key].append(fp8_meta[fp8_meta_tensor_key].amax_history[0]) cls.global_amax_history_buffer[key].append( fp8_meta[fp8_meta_tensor_key].amax_history ) cls.global_scale_buffer[key].append(fp8_meta[fp8_meta_tensor_key].scale) fp8_meta[index_in_buffer].append(len(cls.global_amax_buffer[key]) - 1) fp8_meta[index_in_buffer].append(key) @classmethod def is_fp8_enabled(cls) -> bool: """Is FP8 enabled""" return cls.FP8_ENABLED @classmethod def is_fp8_calibration(cls) -> bool: """Is FP8 calibration""" return cls.FP8_CALIBRATION @classmethod def with_fp8_parameters(cls) -> bool: """Should the parameters be stored as FP8""" return cls.FP8_PARAMETERS @classmethod def with_high_precision_init_val(cls) -> bool: """Should the high precision initial values be stored with FP8 parameters""" return cls.HIGH_PRECISION_INIT_VAL @classmethod def fp8_graph_capturing(cls) -> bool: """Is CUDA graph capture under way?""" return cls.FP8_GRAPH_CAPTURING or torch.cuda.is_current_stream_capturing() @classmethod def is_first_fp8_module(cls): """Returns `True` only the first time when called multiple times from within the same `autocast` context. """ tmp = cls.IS_FIRST_FP8_MODULE cls.IS_FIRST_FP8_MODULE = False return tmp @classmethod def get_fp8_recipe(cls) -> Recipe: """Return the fp8 recipe""" if cls.FP8_RECIPE is not None: return cls.FP8_RECIPE return get_default_fp8_recipe() @classmethod def get_fp8_group(cls) -> Union[dist_group_type, None]: """Return the fp8 group for scale/amax comm""" return cls.FP8_DISTRIBUTED_GROUP @classmethod def get_autocast_state(cls) -> Tuple[bool, bool, Recipe, dist_group_type, bool]: """FP8 autocast state getter""" return ( cls.FP8_ENABLED, cls.FP8_CALIBRATION, cls.FP8_RECIPE, cls.FP8_DISTRIBUTED_GROUP, cls.IS_FIRST_FP8_MODULE, cls.FP8_GRAPH_CAPTURING, ) @classmethod def set_autocast_state( cls, fp8_state: Tuple[bool, bool, DelayedScaling, dist_group_type, bool] ) -> None: """FP8 autocast state setter""" ( cls.FP8_ENABLED, cls.FP8_CALIBRATION, cls.FP8_RECIPE, cls.FP8_DISTRIBUTED_GROUP, cls.IS_FIRST_FP8_MODULE, cls.FP8_GRAPH_CAPTURING, ) = fp8_state @staticmethod def reduce_tensor_across_group_op_max(tensor: torch.Tensor, group: dist_group_type) -> None: """Reduce tensor across given group.""" if torch.distributed.is_initialized(): torch.distributed.all_reduce( tensor, op=torch.distributed.ReduceOp.MAX, group=group, async_op=False, ) @classmethod def reduce_and_update_fp8_tensors( cls, forward: bool = True, ) -> None: """Delayed scaling only. Concatenate, reduce, and split amaxes in the global buffer.""" # global_amax_buffer should only be non-empty for fp8 delayed scaling for buffer_key, amax_buffer in cls.global_amax_buffer.items(): # Check for forward or backward reduction. fwd_update, autocast_key = cls.split_key_in_buffer(buffer_key) if fwd_update != forward: continue if len(amax_buffer) == 0: continue # Retrieve autocast specific args and concat amaxes. recipe, group = cls.autocast_arguments[autocast_key] contiguous_amax = torch.cat(amax_buffer) # Reduction. if ( recipe.reduce_amax and torch.distributed.is_initialized() and torch.distributed.get_world_size(group=group) > 1 ): cls.reduce_tensor_across_group_op_max(contiguous_amax, group) # Amax and scale update. unfused_update = ( bool(int(os.getenv("NVTE_UNFUSED_FP8_UPDATE", "0"))) or callable(recipe.amax_compute_algo) or callable(recipe.scaling_factor_compute_algo) ) if not unfused_update: tex.fused_amax_and_scale_update_after_reduction( contiguous_amax, cls.global_amax_history_buffer[buffer_key], cls.global_scale_buffer[buffer_key], recipe.amax_compute_algo, get_fp8_te_dtype(recipe, forward), recipe.margin, ) else: split_and_copy(contiguous_amax, amax_buffer, [x.numel() for x in amax_buffer]) for amax_history, scale in zip( cls.global_amax_history_buffer[buffer_key], cls.global_scale_buffer[buffer_key], ): _amax_and_scale_update( amax_history, scale, get_fp8_max(recipe, forward), recipe ) @classmethod def get_unique_autocast_key( cls, recipe: Optional[Recipe] = None, group: Optional[dist_group_type] = None, ): """ For FP8, each autocast can be uniquely identified by the recipe and fp8 group. Safely using `hash` as we never cross checkpoint boundaries. """ return f"{str(recipe)}:{hash(group)}" @classmethod def autocast_enter( cls, enabled: bool = False, calibrating: bool = False, fp8_recipe: Optional[Recipe] = None, fp8_group: Optional[dist_group_type] = None, _graph: bool = False, ) -> None: """Set state and tracking variables for entry into FP8 region.""" fp8_recipe = get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe autocast_key = cls.get_unique_autocast_key(fp8_recipe, fp8_group) cls.autocast_arguments[autocast_key] = (fp8_recipe, fp8_group) cls.FP8_ENABLED = enabled cls.FP8_CALIBRATION = calibrating cls.FP8_RECIPE = fp8_recipe cls.FP8_DISTRIBUTED_GROUP = fp8_group cls.FP8_GRAPH_CAPTURING = _graph if cls.AUTOCAST_DEPTH == 0: cls.IS_FIRST_FP8_MODULE = True cls.AUTOCAST_DEPTH += 1 if enabled: fp8_available, reason_for_no_fp8 = cls.is_fp8_available() assert fp8_available, reason_for_no_fp8 if isinstance(fp8_recipe, MXFP8BlockScaling): mxfp8_available, reason_for_no_mxfp8 = cls.is_mxfp8_available() assert mxfp8_available, reason_for_no_mxfp8 if isinstance(fp8_recipe, Float8BlockScaling): fp8_block_available, reason_for_no_fp8_block = cls.is_fp8_block_scaling_available() assert fp8_block_available, reason_for_no_fp8_block if isinstance(fp8_recipe, NVFP4BlockScaling): nvfp4_available, reason_for_no_nvfp4 = cls.is_nvfp4_available() assert nvfp4_available, reason_for_no_nvfp4 @classmethod def autocast_exit(cls, enabled: bool, _graph: bool) -> None: """Set state and tracking variables for exit from FP8 region.""" cls.AUTOCAST_DEPTH -= 1 # Reduce only the non-FP8 weight modules here. # FP8 weight modules are reduced at the end of the optimizer # step after the weight amax is populated. if enabled and cls.AUTOCAST_DEPTH == 0 and not _graph and torch.is_grad_enabled(): # delayed scaling only function, for other recipes (current scaling with any granularity), # this is noop for other recipes because cls.global_amax_buffer is empty list cls.reduce_and_update_fp8_tensors(forward=True) @classmethod def copy_forward_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> None: """Copy the scaling factors and amaxes for recompute forward phase to ensure both forward steps are numerically same. """ # delayed scaling only function, noop for any other recipe if not fp8_meta["recipe"].delayed(): return buffer_position_key = "global_fp8_buffer_pos_fwd_recompute" to_copy = [ fp8_meta["scaling_fwd"].amax_history.clone(), fp8_meta["scaling_fwd"].scale.clone(), ] if buffer_position_key in fp8_meta: cls.fp8_tensors_recompute_buffer[fp8_meta[buffer_position_key]].append(to_copy) else: if len(cls.fp8_tensors_recompute_buffer) == 0: cls.fp8_tensors_recompute_buffer = [deque()] else: cls.fp8_tensors_recompute_buffer.append(deque()) cls.fp8_tensors_recompute_buffer[-1].append(to_copy) fp8_meta[buffer_position_key] = len(cls.fp8_tensors_recompute_buffer) - 1 @classmethod def get_old_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> None: """Switch to the copied scaling factors and amaxes from phase 1 forward for indentical numerical outputs. """ # delayed scaling only function, noop for any other recipe if not fp8_meta["recipe"].delayed(): return # Store updated amaxes and scales from phase 1 post forward. fp8_meta["updated_amax_history_fwd"] = fp8_meta["scaling_fwd"].amax_history.clone() fp8_meta["updated_scale_fwd"] = fp8_meta["scaling_fwd"].scale.clone() # Retrieve stashed amaxes and scales from phase 1 pre forward. buffer_position_key = "global_fp8_buffer_pos_fwd_recompute" stashed_fp8_meta = cls.fp8_tensors_recompute_buffer[fp8_meta[buffer_position_key]].popleft() # Replace amaxes and scales with stashed values for phase 2 forward fp8_meta["scaling_fwd"].amax_history.copy_(stashed_fp8_meta[0]) fp8_meta["scaling_fwd"].scale.copy_(stashed_fp8_meta[1]) @staticmethod def restore_fp8_meta_tensors(fp8_meta: Dict[str, Any]) -> None: """Restore latest scaling factors and amaxes after recompute forward run.""" # delayed scaling only function, noop for any other recipe if not fp8_meta["recipe"].delayed(): return fp8_meta["scaling_fwd"].amax_history.copy_(fp8_meta["updated_amax_history_fwd"]) fp8_meta["scaling_fwd"].scale.copy_(fp8_meta["updated_scale_fwd"]) @contextmanager def fp8_model_init( enabled: bool = True, recipe: Optional[Recipe] = None, preserve_high_precision_init_val: bool = False, ) -> None: """ .. warning:: fp8_model_init is deprecated and will be removed in a future release. Use quantized_model_init(enabled=..., recipe=..., preserve_high_precision_init_val=...) instead. """ warnings.warn( "fp8_model_init is deprecated and will be removed in a future release. " "Use quantized_model_init(" "enabled=..., recipe=..., preserve_high_precision_init_val=...) instead.", category=DeprecationWarning, stacklevel=2, ) # Call new implementation. with quantized_model_init( enabled=enabled, recipe=recipe, preserve_high_precision_init_val=preserve_high_precision_init_val, ): yield @contextmanager def quantized_model_init( enabled: bool = True, recipe: Optional[Recipe] = None, preserve_high_precision_init_val: bool = False, ) -> None: """ Context manager for initialization of quantized parameters. Example usage: .. code-block:: python with quantized_model_init(enabled=True): model = transformer_engine.pytorch.Linear(768, 768) # Preserving high precision initial value to initialize master weight with quantized_model_init(enabled=True, preserve_high_precision_init_val=True): model = transformer_engine.pytorch.Linear(768, 768) master_weight = model.weight.get_high_precision_init_val() model.weight.clear_high_precision_init_val() Parameters ---------- enabled: bool, default = `True` when enabled, Transformer Engine modules created inside this `quantized_model_init` region will hold only quantized copies of its parameters, as opposed to the default behavior where both higher precision and quantized copies are present. Setting this option to `True` may result in lower memory consumption and is especially useful for scenarios like: * full model training using optimizer with master weights, where the high precision copies of weights are already present in the optimizer. * inference, where only the quantized copies of the parameters are used. * LoRA-like fine-tuning, where the main parameters of the model do not change. recipe: transformer_engine.common.recipe.Recipe, default = `None` Recipe used to create the parameters. If left to None, it uses the default recipe. preserve_high_precision_init_val: bool, default = `False` when enabled, store the high precision tensor used to initialize quantized parameters in CPU memory, and add two function attributes named `get_high_precision_init_val()` and `clear_high_precision_init_val()` to quantized parameters to get/clear this high precision tensor. The purpose is that users can use this high-precision copy to initialize master weights, avoiding the loss of precision that can occur when using quantized parameters directly. Note that after the master weights are initialized, users should call `clear_high_precision_init_val()` to release this CPU memory. This functionality is *EXPERIMENTAL*. """ _fp8_parameters = FP8GlobalStateManager.FP8_PARAMETERS _fp8_recipe = FP8GlobalStateManager.FP8_RECIPE _high_precision_init_val = FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL FP8GlobalStateManager.FP8_PARAMETERS = enabled FP8GlobalStateManager.FP8_RECIPE = get_default_fp8_recipe() if recipe is None else recipe FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL = preserve_high_precision_init_val try: yield finally: FP8GlobalStateManager.FP8_PARAMETERS = _fp8_parameters FP8GlobalStateManager.FP8_RECIPE = _fp8_recipe FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL = _high_precision_init_val @contextmanager def fp8_autocast( enabled: bool = True, calibrating: bool = False, fp8_recipe: Optional[Recipe] = None, fp8_group: Optional[dist_group_type] = None, _graph: bool = False, ) -> None: """ .. warning:: fp8_autocast is deprecated and will be removed in a future release. Use autocast(enabled=..., calibrating=..., recipe=..., group=..., _graph=...) instead. """ warnings.warn( "fp8_autocast is deprecated and will be removed in a future release. " "Use autocast(enabled=..., calibrating=..., recipe=..., group=..., _graph=...) instead.", category=DeprecationWarning, stacklevel=2, ) # Call new implementation. with autocast( enabled=enabled, calibrating=calibrating, recipe=fp8_recipe, amax_reduction_group=fp8_group, _graph=_graph, ): yield @contextmanager def autocast( enabled: bool = True, calibrating: bool = False, recipe: Optional["Recipe"] = None, amax_reduction_group: Optional["dist_group_type"] = None, _graph: bool = False, ) -> None: """ Context manager for quantization schemes like FP8 or FP4. .. code-block:: python with autocast(enabled=True): out = model(inp) .. note:: Support for FP8 in the Linear layer of Transformer Engine is currently limited to tensors with shapes where both dimensions are divisible by 16. In terms of the input to the full Transformer network, this typically requires padding sequence length to be multiple of 16. .. note:: When :attr:`recipe.reduce_amax==True`, any module must not be invoked more than once inside a single `autocast` region. This is unsupported behavior because the amax reduction is handled during the exit of the `autocast` context. Calling the same module more than once inside an `autocast` region overrides the amax tensors before reduction can occur. Parameters ---------- enabled: bool, default = `True` whether or not to enable low precision quantization (FP8/FP4). calibrating: bool, default = `False` calibration mode allows collecting statistics such as amax and scale data of quantized tensors even when executing without quantization enabled. This is useful for saving an inference ready checkpoint while training using a higher precision. recipe: recipe.Recipe, default = `None` recipe used for low precision quantization. amax_reduction_group: torch._C._distributed_c10d.ProcessGroup, default = `None` distributed group over which amaxes for the quantized tensors are reduced at the end of each training step. """ if enabled: check_recipe_support(recipe) # Save current state so we always restore it on exit. fp8_state = FP8GlobalStateManager.get_autocast_state() FP8GlobalStateManager.autocast_enter( enabled=enabled, calibrating=calibrating, fp8_recipe=recipe, fp8_group=amax_reduction_group, _graph=_graph, ) try: yield finally: FP8GlobalStateManager.set_autocast_state(fp8_state) FP8GlobalStateManager.autocast_exit(enabled, _graph=_graph) def _update_amax_history(amax_history: torch.Tensor) -> torch.Tensor: """Update amax history and set next amax to zero.""" if amax_history.shape[0] > 1: new_amax_history = torch.roll(amax_history, -1, 0) amax_history.copy_(new_amax_history) amax_history[0].fill_(0.0) return amax_history @torch.jit.script def _default_get_amax_and_update_history( amax_history: torch.Tensor, amax_compute_algo: str, ) -> Tuple[torch.Tensor, torch.Tensor]: """Default function to obtain amax from history.""" if amax_compute_algo == "max": amax = torch.max(amax_history, dim=0).values else: # amax_compute_algo == "most_recent" amax = amax_history[0].clone() amax_history = _update_amax_history(amax_history) return amax_history, amax @jit_fuser def _default_sf_compute( amax: torch.Tensor, scale: torch.Tensor, fp8_max: float, margin: int, _fp32_max: float = torch.finfo(torch.float32).max, # finfo not available in jitter ) -> torch.Tensor: """Default function to convert amax to scaling factor. Computing the scaling factor requires consideration of the following scenarios: 1. amax == 0: No action is possible, set scale to the previous scale (or 1). 2. 0 < amax < tiny_amax The amax is too tiny that the scale becomes infinite in FP32. Set scale = FP32_max 3. tiny_amax <= amax < FP32_max: Set scale = FP8_max (or scaled_max) / amax 4. When amax == inf or amax == nan: No action is possible, set scale to the previous scale (or 1). """ sf = (fp8_max / amax) / (2**margin) sf = torch.where(amax > 0.0, sf, scale) sf = torch.where(torch.isfinite(amax), sf, scale) sf = torch.where(torch.isinf(sf), torch.full_like(sf, _fp32_max), sf) scale.copy_(sf) return scale def _compute_amax_and_update_history( amax_history: torch.Tensor, amax_compute_algo: Union[Callable, str], ) -> Tuple[torch.Tensor, torch.Tensor]: """Obtain the amax from the history.""" if callable(amax_compute_algo): amax = amax_compute_algo(amax_history) amax_history = _update_amax_history(amax_history) return amax_history, amax return _default_get_amax_and_update_history( amax_history, amax_compute_algo, ) def _compute_scaling_factor( amax: torch.Tensor, scale: torch.Tensor, fp8_max: float, recipe: DelayedScaling, ) -> torch.Tensor: """Convert amax to scaling factor.""" if recipe.scaling_factor_compute_algo is None: return _default_sf_compute( amax, scale, fp8_max, recipe.margin, ) return recipe.scaling_factor_compute_algo(amax, scale, fp8_max, recipe) def _amax_and_scale_update( amax_history: torch.Tensor, scale: torch.Tensor, fp8_max: float, recipe: DelayedScaling, ) -> None: """Updates FP8 meta tensors.""" new_amax_history, amax = _compute_amax_and_update_history( amax_history, recipe.amax_compute_algo, ) new_scale = _compute_scaling_factor(amax, scale, fp8_max, recipe) scale.copy_(new_scale) amax_history.copy_(new_amax_history) def split_and_copy( buffer: torch.Tensor, outputs: List[torch.Tensor], chunk_sizes: List[int], ) -> None: """Split `buffer` by `chunk_sizes` and copy into `outputs`.""" splits = buffer.split(chunk_sizes) torch._foreach_copy_(outputs, splits) class RecipeState(abc.ABC): """Configuration and state for a quantization recipe. This is a builder class for quantizers, which are in turn builder classes for quantized tensors. This class may pack together the state for multiple quantizers, which is helpful for applying fused kernels with less overhead. """ @staticmethod def create( recipe: Recipe, *, mode: str, num_quantizers: int = 1, device: Optional[torch.device] = None, ) -> RecipeState: """Factory method to create the state for a quantization recipe Parameters ---------- recipe: Recipe Quantization recipe. mode: {"forward", "backward"} Training stage where quantization will be performed. num_quantizers: int, default = 1 Number of quantizers to create state for. device: torch.device, default = default CUDA device Device for quantized tensors. Returns ------- RecipeState: Quantization recipe state. """ cls = None if recipe.delayed(): cls = DelayedScalingRecipeState elif recipe.mxfp8(): cls = MXFP8BlockScalingRecipeState elif recipe.float8_current_scaling(): cls = Float8CurrentScalingRecipeState elif recipe.float8_block_scaling(): cls = Float8BlockScalingRecipeState elif recipe.nvfp4(): cls = NVFP4BlockScalingRecipeState elif recipe.custom(): cls = CustomRecipeState else: raise ValueError(f"{recipe.__class__.__name__} is not supported") return cls( recipe, mode=mode, num_quantizers=num_quantizers, device=device, ) @abc.abstractmethod def make_quantizers(self) -> list: """Convert recipe state to quantizers. Quantizers are builder classes for quantized tensors. They are typically used to convert a high-precision tensor (e.g. in FP32 or BF16) into a quantized tensor (e.g. in FP8). """ class DelayedScalingRecipeState(RecipeState): """State for FP8 quantization with per-tensor delayed scaling. Delayed scaling recipe requires a scaling factor (applied when casting to FP8) and a history of max-abs values ("amax") from recent FP8 casts for updating the scaling factor. The scale update is handled externally by `FP8GlobalStateManager`. """ recipe: DelayedScaling mode: str dtype: tex.DType scale: torch.Tensor amax_history: torch.Tensor def __init__( self, recipe: DelayedScaling, *, mode: str, num_quantizers: int = 1, device: Optional[torch.device] = None, ) -> None: self.recipe = recipe self.mode = mode self.num_quantizers = num_quantizers self.dtype = get_fp8_te_dtype(recipe, mode == "forward") # Allocate buffers if device is None: device = torch.device("cuda") self.scale = torch.ones(num_quantizers, dtype=torch.float32, device=device) self.amax_history = torch.zeros( recipe.amax_history_len, num_quantizers, dtype=torch.float32, device=device, ) def make_quantizers(self) -> list: # TODO(ksivamani); Find better design for this, adding here to avoid circular import. from .tensor.float8_tensor import Float8Quantizer return [ Float8Quantizer(self.scale[i], self.amax_history[0][i].reshape((1,)), self.dtype) for i in range(self.num_quantizers) ] class Float8CurrentScalingRecipeState(RecipeState): """Configuration for Per-tensor current scaling quantization. Per-tensor current quantization does not require state. """ recipe: Float8CurrentScaling mode: str dtype: tex.DType device: torch.device def __init__( self, recipe: Float8CurrentScaling, *, mode: str, num_quantizers: int = 1, device: Optional[torch.device] = None, ) -> None: self.recipe = recipe self.mode = mode self.num_quantizers = num_quantizers self.dtype = get_fp8_te_dtype(recipe, mode == "forward") # Allocate buffers if device is None: device = torch.device("cuda") self.device = device def make_quantizers(self) -> list: from .tensor.float8_tensor import Float8CurrentScalingQuantizer return [ Float8CurrentScalingQuantizer( self.dtype, device=self.device, force_pow_2_scales=self.recipe.use_power_2_scales ) for i in range(self.num_quantizers) ] class MXFP8BlockScalingRecipeState(RecipeState): """Configuration for MXFP8 quantization. MXFP8 quantization does not require state. """ recipe: MXFP8BlockScaling mode: str dtype: tex.DType def __init__( self, recipe: MXFP8BlockScaling, *, mode: str, num_quantizers: int = 1, device: Optional[torch.device] = None, ) -> None: self.recipe = recipe self.mode = mode self.num_quantizers = num_quantizers self.dtype = get_fp8_te_dtype(recipe, mode == "forward") # Allocate buffers if device is None: device = torch.device("cuda") def make_quantizers(self) -> list: # TODO(ksivamani); Find better design for this, adding here to avoid circular import. from .tensor.mxfp8_tensor import MXFP8Quantizer return [MXFP8Quantizer(self.dtype) for i in range(self.num_quantizers)] class Float8BlockScalingRecipeState(RecipeState): """Configuration for Float8BlockScaling quantization. Float8BlockScaling quantization does not require state, but different quantizers use different modes. """ recipe: Float8BlockScaling mode: str qx_dtype: tex.DType qw_dtype: tex.DType qgrad_dtype: tex.DType def __init__( self, recipe: Float8BlockScaling, *, mode: str, num_quantizers: int = 1, device: Optional[torch.device] = None, ) -> None: self.recipe = recipe self.mode = mode self.num_quantizers = num_quantizers self.qx_dtype = get_fp8_te_dtype(recipe, True) self.qw_dtype = get_fp8_te_dtype(recipe, True) self.qgrad_dtype = get_fp8_te_dtype(recipe, False) # Allocate buffers if device is None: device = torch.device("cuda") self.device = device def make_quantizers(self) -> list: # TODO(ksivamani); Find better design for this, adding here to avoid circular import. from .tensor.float8_blockwise_tensor import Float8BlockQuantizer if self.mode == "forward": # The index convention (coming from base.py set_meta_tensor) # is somewhat awkward, and doesn't play nicely with QuantizeOp, # which is not associated with a GEMM. assert self.num_quantizers % 3 == 0 # x, w, output per gemm return list( itertools.chain.from_iterable( [ [ Float8BlockQuantizer( fp8_dtype=self.qx_dtype, rowwise=True, columnwise=True, amax_epsilon=self.recipe.fp8_quant_fwd_inp.amax_epsilon, force_pow_2_scales=self.recipe.fp8_quant_fwd_inp.power_2_scale, block_scaling_dim=self.recipe.x_block_scaling_dim, ), Float8BlockQuantizer( fp8_dtype=self.qw_dtype, rowwise=True, columnwise=True, amax_epsilon=self.recipe.fp8_quant_fwd_weight.amax_epsilon, force_pow_2_scales=self.recipe.fp8_quant_fwd_weight.power_2_scale, block_scaling_dim=self.recipe.w_block_scaling_dim, ), Float8BlockQuantizer( fp8_dtype=self.qx_dtype, rowwise=True, columnwise=True, amax_epsilon=self.recipe.fp8_quant_fwd_inp.amax_epsilon, force_pow_2_scales=self.recipe.fp8_quant_fwd_inp.power_2_scale, block_scaling_dim=self.recipe.x_block_scaling_dim, ), ] for _ in range(self.num_quantizers // 3) ] ) ) assert self.mode == "backward", f"Unexpected mode {self.mode}" assert self.num_quantizers % 2 == 0 # grad_output and grad_input per gemm return list( itertools.chain.from_iterable( [ [ Float8BlockQuantizer( fp8_dtype=self.qgrad_dtype, rowwise=True, columnwise=True, amax_epsilon=self.recipe.fp8_quant_bwd_grad.amax_epsilon, force_pow_2_scales=self.recipe.fp8_quant_bwd_grad.power_2_scale, block_scaling_dim=self.recipe.grad_block_scaling_dim, ), Float8BlockQuantizer( fp8_dtype=self.qgrad_dtype, rowwise=True, columnwise=True, amax_epsilon=self.recipe.fp8_quant_bwd_grad.amax_epsilon, force_pow_2_scales=self.recipe.fp8_quant_bwd_grad.power_2_scale, block_scaling_dim=self.recipe.grad_block_scaling_dim, ), ] for _ in range(self.num_quantizers // 2) ] ) ) class NVFP4BlockScalingRecipeState(RecipeState): """Configuration for NVFP4 quantization. NVFP4 quantization does not require state. """ recipe: NVFP4BlockScaling mode: str dtype: tex.DType def __init__( self, recipe: NVFP4BlockScaling, *, mode: str, num_quantizers: int = 1, device: Optional[torch.device] = None, ) -> None: self.recipe = recipe self.mode = mode self.num_quantizers = num_quantizers self.dtype = get_fp4_te_dtype(recipe) # Allocate buffers if device is None: device = torch.device("cuda") def make_quantizers(self) -> list: from .tensor.nvfp4_tensor import NVFP4Quantizer # The index convention (coming from base.py set_meta_tensor) # is somewhat awkward. It assumes forward quantizers are # ordered [input, weight, output, ...] and backward quantizers # are ordered [grad_output, grad_input, ...]. This doesn't # play nicely with fusible ops: Linear op doesn't own output # or grad input quantizers, Quantize op only owns input and # grad output quantizers. if self.mode == "forward": def _make_quantizer(idx: int) -> NVFP4Quantizer: qparams = ( self.recipe.fp4_quant_fwd_weight if idx % 3 == 1 else self.recipe.fp4_quant_fwd_inp ) return NVFP4Quantizer( fp4_dtype=self.dtype, rowwise=True, columnwise=True, with_rht=qparams.random_hadamard_transform, with_post_rht_amax=qparams.random_hadamard_transform, with_2d_quantization=qparams.fp4_2d_quantization, stochastic_rounding=qparams.stochastic_rounding, ) return [_make_quantizer(idx) for idx in range(self.num_quantizers)] if self.mode == "backward": return [ NVFP4Quantizer( fp4_dtype=self.dtype, rowwise=True, columnwise=True, with_rht=self.recipe.fp4_quant_bwd_grad.random_hadamard_transform, with_post_rht_amax=self.recipe.fp4_quant_bwd_grad.random_hadamard_transform, with_2d_quantization=self.recipe.fp4_quant_bwd_grad.fp4_2d_quantization, stochastic_rounding=self.recipe.fp4_quant_bwd_grad.stochastic_rounding, ) for _ in range(self.num_quantizers) ] raise RuntimeError(f"Unexpected recipe mode ({self.mode})") class CustomRecipeState(RecipeState): """State for CustomRecipe: produce quantizers per tensor.""" recipe: CustomRecipe mode: str num_quantizers: int device: Optional[torch.device] def __init__( self, recipe: CustomRecipe, *, mode: str, num_quantizers: int = 1, device: Optional[torch.device] = None, ) -> None: self.recipe = recipe self.mode = mode self.num_quantizers = num_quantizers if device is None: device = torch.device("cuda") self.device = device if getattr(recipe, "qfactory", None) is None: raise ValueError("CustomRecipe requires `qfactory`.") def make_quantizers(self) -> list: qfactory = self.recipe.qfactory out = [] # TODO(negvet): make_quantizers() should take roles from the operation # Hardcode linear-specific roles for now roles: List[str] if self.mode == "forward": roles = [ ("linear_input", "linear_weight", "linear_output")[i % 3] for i in range(self.num_quantizers) ] elif self.mode == "backward": roles = [ ("linear_grad_output", "linear_grad_input")[i % 2] for i in range(self.num_quantizers) ] else: roles = ["unknown"] * self.num_quantizers for i in range(self.num_quantizers): # Get quantizer from the user defined factory quantizer = qfactory(roles[i]) out.append(quantizer) return out