Commit 063ef88d authored by wenjh

Merge nv main up to v2.10.0.dev0


Signed-off-by: wenjh <wenjh@sugon.com>
parents 91670b05 5624dbb4
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Quantization utilities for TransformerEngine"""
from __future__ import annotations
import abc
import itertools
import functools
import warnings
import os
from contextlib import contextmanager
from collections import deque
from typing import Callable, List, Optional, Dict, Any, Tuple, Union
import torch
import transformer_engine_torch as tex
from transformer_engine.common.recipe import (
Recipe,
DelayedScaling,
Format,
MXFP8BlockScaling,
Float8CurrentScaling,
Float8BlockScaling,
NVFP4BlockScaling,
CustomRecipe,
)
from .constants import dist_group_type
from .utils import get_device_compute_capability
from .jit import jit_fuser
from torch.utils.cpp_extension import IS_HIP_EXTENSION
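# Environment knobs, read once at import time:
# - NVTE_INT8_SIM_FP8: on DCU (ROCm) devices without native FP8, simulate FP8
#   numerics with INT8 (see check_fp8_support / check_fp8_block_scaling_support below).
# - NVTE_INT8_SIM_FP8_TENSORWISE: the same simulation, applied by the per-tensor
#   current-scaling quantizer.
# - NVTE_BLOCKWISE_FP8_BLOCK_LEN: block length used by blockwise FP8 quantization
#   (defaults to 128).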
int8_simulation_fp8 = bool(int(os.getenv("NVTE_INT8_SIM_FP8", "0")))
int8_simulation_fp8_tensorwise = bool(int(os.getenv("NVTE_INT8_SIM_FP8_TENSORWISE", "0")))
blockwise_fp8_block_len = int(os.getenv("NVTE_BLOCKWISE_FP8_BLOCK_LEN", "128"))
__all__ = [
"autocast",
"quantized_model_init",
"is_fp8_available",
"is_mxfp8_available",
"is_fp8_block_scaling_available",
"is_nvfp4_available",
"get_default_recipe",
]
if IS_HIP_EXTENSION:
from transformer_engine.pytorch.utils import is_K100_AI, is_BW
@functools.lru_cache(maxsize=None)
def check_fp8_support() -> Tuple[bool, str]:
"""Return if fp8 support is available"""
if IS_HIP_EXTENSION:
if (is_K100_AI() or is_BW()) and int8_simulation_fp8:
return True, "DCU turn on fp8 simulation with int8"
else:
return False, "DCU not support fp8 for now"
else:
if get_device_compute_capability() >= (9, 0): # hopper and above
return True, ""
if get_device_compute_capability() < (8, 9): # pre-ada
return False, "Device compute capability 8.9 or higher required for FP8 execution."
if tex.get_cublasLt_version() < 120103:
return False, "CublasLt version 12.1.3.x or higher required for FP8 execution on Ada."
if float(torch.version.cuda) < 12.1:
return False, "Cuda version 12.1 or higher required for FP8 execution on Ada."
return True, ""
@functools.lru_cache(maxsize=None)
def check_mxfp8_support() -> Tuple[bool, str]:
"""Return if fp8 support is available"""
if get_device_compute_capability() >= (12, 0):
return False, "MXFP8 (for all gemm layouts) is not supported on 12.0+ architectures yet."
if get_device_compute_capability() >= (10, 0): # blackwell and above
return True, ""
return False, "Device compute capability 10.0 or higher required for MXFP8 execution."
@functools.lru_cache(maxsize=None)
def check_nvfp4_support() -> Tuple[bool, str]:
"""Return if nvfp4 support is available"""
if IS_HIP_EXTENSION:
return False, "NVFP4 is not supported on rocm platform."
else:
if get_device_compute_capability() >= (10, 0): # blackwell and above
return True, ""
return False, "Device compute capability 10.0 or higher required for NVFP4 execution."
@functools.lru_cache(maxsize=None)
def check_fp8_block_scaling_support() -> Tuple[bool, str]:
"""Return if fp8 block scaling support is available"""
if IS_HIP_EXTENSION:
if (is_K100_AI() or is_BW()) and int8_simulation_fp8:
return True, ""
else:
return False, "FP8 block scaling is not currently supported on DCU"
if get_device_compute_capability() >= (9, 0) and float(torch.version.cuda) >= 12.9:
return True, ""
return (
False,
"FP8 block scaled GEMM requires compute capability 9.0 or higher and CUDA >= 12.9.",
)
def check_recipe_support(recipe: Recipe) -> None:
"""Check if the given recipe is supported."""
recipe_supported = True
unsupported_reason = ""
if isinstance(recipe, (DelayedScaling, Float8CurrentScaling)):
recipe_supported, unsupported_reason = check_fp8_support()
elif isinstance(recipe, Float8BlockScaling):
recipe_supported, unsupported_reason = check_fp8_block_scaling_support()
elif isinstance(recipe, MXFP8BlockScaling):
recipe_supported, unsupported_reason = check_mxfp8_support()
assert recipe_supported, unsupported_reason
def get_default_fp8_recipe() -> Recipe:
"""FP8 recipe with default args."""
if check_mxfp8_support()[0]:
return MXFP8BlockScaling()
if get_device_compute_capability() >= (12, 0):
# This is a temporary restriction until MXFP8 is supported for all gemm layouts.
return Float8CurrentScaling()
return DelayedScaling()
def get_default_recipe() -> Recipe:
"""Returns the default training recipe based on available device."""
return get_default_fp8_recipe()
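# Default-recipe selection, summarizing the checks above (illustrative):
#   - compute capability >= 10.0 and < 12.0 (Blackwell): MXFP8BlockScaling
#   - compute capability >= 12.0 (MXFP8 not yet supported for all GEMM layouts): Float8CurrentScaling
#   - otherwise (Hopper/Ada, or DCU with INT8 simulation): DelayedScaling
#
#     recipe = get_default_recipe()
#     print(type(recipe).__name__)   # e.g. "DelayedScaling" on Hopper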
def get_fp8_torch_dtype(fp8_recipe: Recipe, fprop_tensor: bool = True) -> torch.dtype:
"""Get fp8 data type according to recipe and tensor"""
if fp8_recipe.fp8_format == Format.E4M3 or (
fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor
):
return torch.float8_e4m3fn
return torch.float8_e5m2
def get_fp8_te_dtype(fp8_recipe: Recipe, fprop_tensor: bool = True) -> tex.DType:
"""Get fp8 data type according to recipe and tensor"""
if fp8_recipe.fp8_format == Format.E4M3 or (
fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor
):
return tex.DType.kFloat8E4M3
return tex.DType.kFloat8E5M2
def get_fp4_te_dtype(fp4_recipe: Recipe) -> tex.DType:
"""Get fp4 data type according to recipe and tensor"""
if fp4_recipe.fp4_format == Format.E2M1:
return tex.DType.kFloat4E2M1
raise ValueError(f"Unsupported FP4 format: {fp4_recipe.fp4_format}")
def get_fp8_max(fp8_recipe: Recipe, fprop_tensor: bool = True) -> float:
"""Get max representable FP8 value."""
if fp8_recipe.fp8_format == Format.E4M3 or (
fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor
):
return Format.E4M3.value.max_fwd
return Format.E5M2.value.max_fwd
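# Format selection at a glance (summarizing the helpers above):
#   Format.E4M3   -> E4M3 for both fprop and grad tensors
#   Format.HYBRID -> E4M3 for fprop tensors, E5M2 for grad tensors
#   Format.E5M2   -> E5M2 for both
# For example, with a HYBRID recipe:
#     get_fp8_te_dtype(recipe, fprop_tensor=False)  # tex.DType.kFloat8E5M2
#     get_fp8_max(recipe, fprop_tensor=True)        # Format.E4M3.value.max_fwd (448.0)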
def is_fp8_available(return_reason: bool = False) -> Union[bool, Tuple[bool, str]]:
"""
Determine if FP8 support is available for the delayed
scaling and per-tensor current scaling recipes.
Parameters
----------
return_reason : bool, optional
If ``False`` (default), return only a boolean indicating availability.
If ``True``, return a tuple ``(is_available, reason)`` where ``reason`` provides
a human-readable explanation when required support is not available. The reason
will be an empty string if support for FP8 is available.
"""
if return_reason:
return check_fp8_support()
return check_fp8_support()[0]
def is_mxfp8_available(return_reason: bool = False) -> Union[bool, Tuple[bool, str]]:
"""
Determine if support is available for the MXFP8 recipe.
Parameters
----------
return_reason : bool, optional
If ``False`` (default), return only a boolean indicating availability.
If ``True``, return a tuple ``(is_available, reason)`` where ``reason`` provides
a human-readable explanation when required support is not available. The reason
will be an empty string if support for MXFP8 is available.
"""
if return_reason:
return check_mxfp8_support()
return check_mxfp8_support()[0]
def is_fp8_block_scaling_available(return_reason: bool = False) -> Union[bool, Tuple[bool, str]]:
"""
Determine if support is available for the FP8 block scaling recipe.
Parameters
----------
return_reason : bool, optional
If ``False`` (default), return only a boolean indicating availability.
If ``True``, return a tuple ``(is_available, reason)`` where ``reason`` provides
a human-readable explanation when required support is not available. The reason
will be an empty string if support for FP8 block scaling is available.
"""
if return_reason:
return check_fp8_block_scaling_support()
return check_fp8_block_scaling_support()[0]
def is_nvfp4_available(return_reason: bool = False) -> Union[bool, Tuple[bool, str]]:
"""
Determine if support is available for the NVFP4 recipe.
Parameters
----------
return_reason : bool, optional
If ``False`` (default), return only a boolean indicating availability.
If ``True``, return a tuple ``(is_available, reason)`` where ``reason`` provides
a human-readable explanation when required support is not available. The reason
will be an empty string if support for NVFP4 is available.
"""
if return_reason:
return check_nvfp4_support()
return check_nvfp4_support()[0]
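# Illustrative capability query (a minimal sketch; assumes a CUDA or ROCm device is visible):
#
#     ok, why = is_mxfp8_available(return_reason=True)
#     if not ok:
#         print(f"Falling back from MXFP8: {why}")
#         recipe = get_default_recipe()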
class FP8GlobalStateManager:
"""Class to keep track of and manipulate the global
FP8 state at different stages of execution.
"""
FP8_ENABLED = False
FP8_CALIBRATION = False
FP8_RECIPE = None
FP8_DISTRIBUTED_GROUP = None
FP8_PARAMETERS = False
HIGH_PRECISION_INIT_VAL = False
IS_FIRST_FP8_MODULE = False
FP8_GRAPH_CAPTURING = False
AUTOCAST_DEPTH = 0
global_amax_buffer = {}
global_amax_history_buffer = {}
global_scale_buffer = {}
fp8_tensors_recompute_buffer = []
fp8_available = None
reason_for_no_fp8 = ""
autocast_arguments = {}
skip_fp8_weight_update_tensor = None
mxfp8_available = None
reason_for_no_mxfp8 = ""
fp8_block_scaling_available = None
reason_for_no_fp8_block_scaling = ""
nvfp4_available = None
reason_for_no_nvfp4 = ""
@classmethod
def reset(cls) -> None:
"""Reset the global state"""
cls.FP8_ENABLED = False
cls.FP8_CALIBRATION = False
cls.FP8_RECIPE = None
cls.FP8_DISTRIBUTED_GROUP = None
cls.FP8_PARAMETERS = False
cls.HIGH_PRECISION_INIT_VAL = False
cls.IS_FIRST_FP8_MODULE = False
cls.FP8_GRAPH_CAPTURING = False
cls.AUTOCAST_DEPTH = 0
cls.global_amax_buffer = {}
cls.global_amax_history_buffer = {}
cls.global_scale_buffer = {}
cls.fp8_tensors_recompute_buffer = []
cls.fp8_available = None
cls.reason_for_no_fp8 = ""
cls.autocast_arguments = {}
cls.skip_fp8_weight_update_tensor = None
cls.mxfp8_available = None
cls.reason_for_no_mxfp8 = ""
cls.fp8_block_scaling_available = None
cls.reason_for_no_fp8_block_scaling = ""
cls.nvfp4_available = None
cls.reason_for_no_nvfp4 = ""
@classmethod
def set_skip_fp8_weight_update_tensor(cls, skip: bool) -> None:
"""`skip_fp8_weight_update_tensor` inplace setter."""
if cls.skip_fp8_weight_update_tensor is None:
cls.skip_fp8_weight_update_tensor = torch.empty(1, dtype=torch.float32, device="cuda")
cls.skip_fp8_weight_update_tensor.fill_(skip)
@classmethod
def get_skip_fp8_weight_update_tensor(cls) -> Optional[torch.Tensor]:
"""`skip_fp8_weight_update_tensor` getter."""
return cls.skip_fp8_weight_update_tensor
@classmethod
def is_fp8_available(cls) -> Tuple[bool, str]:
"""Return if fp8 support is available"""
return check_fp8_support()
@classmethod
def is_mxfp8_available(cls) -> Tuple[bool, str]:
"""Return if MXFP8/current scaling support is available."""
return check_mxfp8_support()
@classmethod
def is_fp8_block_scaling_available(cls) -> Tuple[bool, str]:
"""Return if Float8 block scaling support is available."""
return check_fp8_block_scaling_support()
@classmethod
def is_nvfp4_available(cls) -> Tuple[bool, str]:
"""Return if NVFP4 support is available."""
return check_nvfp4_support()
@staticmethod
def get_meta_tensor_key(forward: bool = True) -> str:
"""Returns scaling key in `fp8_meta`."""
if forward:
return "scaling_fwd"
return "scaling_bwd"
@staticmethod
def get_fwd_bwd_key(forward: bool = True) -> str:
"""Convert bool `forward` to string."""
return "forward" if forward else "backward"
@classmethod
def get_buffer_info(cls) -> str:
"""
Returns a key for `fp8_meta` that stores the module's index
in the global buffers along with autocast information.
"""
return "buffer_index_and_autocast_key"
@classmethod
def get_key_in_buffer(
cls,
forward: bool,
fp8_recipe: Recipe,
fp8_group: dist_group_type,
) -> str:
"""Returns a key into the global FP8 buffers."""
autocast_key = cls.get_unique_autocast_key(fp8_recipe, fp8_group)
fwd_bwd_key = cls.get_fwd_bwd_key(forward)
return f"{fwd_bwd_key}_{autocast_key}"
@classmethod
def split_key_in_buffer(cls, key: str) -> Tuple[bool, str]:
"""Splits buffer key into relevant parts."""
forward, autocast_key = key.split("_", 1)
forward = forward == "forward"
return forward, autocast_key
@classmethod
def add_fp8_tensors_to_global_buffer(
cls,
fp8_meta: Dict[str, Any],
) -> None:
"""
Delayed scaling only.
The amax reduction process happens completely outside the FP8 modules.
To participate in the reduction, the only role played by a module is
to call this function in order to append it's FP8 tensor into a global
buffer. There are 5 global buffers maintained, one each for amax, amax
history, scale, scale-inverse, and non-weight-mask. Each buffer has
keys that hold FP8 tensors. Keys have a `forward_` or `backward_` prefix
to indicate the type of FP8 tensor, since the forward and backward
reductions happen separately.
Note: For CG capture, this method is called from the graphed
wrapper. For non CG case, it's called from within the module.
"""
# delayed scaling only function, noop for any other recipe
if not fp8_meta["recipe"].delayed():
return
# Every module must call this function exactly once since
# the amax tensors are static. Ensures that compatibility
# with non-graphed modules is maintained.
index_in_buffer = cls.get_buffer_info() # Same index for fwd/bwd fp8 tensors.
if index_in_buffer in fp8_meta:
return
fp8_meta[index_in_buffer] = []
for forward in (True, False):
fp8_meta_tensor_key = cls.get_meta_tensor_key(forward=forward)
if fp8_meta_tensor_key not in fp8_meta:
# Handles non-parameter FP8 modules, e.g. DPA.
continue
key = cls.get_key_in_buffer(forward, fp8_meta["recipe"], fp8_meta["fp8_group"])
if key not in cls.global_amax_buffer:
cls.global_amax_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history[0]]
cls.global_amax_history_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history]
cls.global_scale_buffer[key] = [fp8_meta[fp8_meta_tensor_key].scale]
else:
cls.global_amax_buffer[key].append(fp8_meta[fp8_meta_tensor_key].amax_history[0])
cls.global_amax_history_buffer[key].append(
fp8_meta[fp8_meta_tensor_key].amax_history
)
cls.global_scale_buffer[key].append(fp8_meta[fp8_meta_tensor_key].scale)
fp8_meta[index_in_buffer].append(len(cls.global_amax_buffer[key]) - 1)
fp8_meta[index_in_buffer].append(key)
@classmethod
def is_fp8_enabled(cls) -> bool:
"""Is FP8 enabled"""
return cls.FP8_ENABLED
@classmethod
def is_fp8_calibration(cls) -> bool:
"""Is FP8 calibration"""
return cls.FP8_CALIBRATION
@classmethod
def with_fp8_parameters(cls) -> bool:
"""Should the parameters be stored as FP8"""
return cls.FP8_PARAMETERS
@classmethod
def with_high_precision_init_val(cls) -> bool:
"""Should the high precision initial values be stored with FP8 parameters"""
return cls.HIGH_PRECISION_INIT_VAL
@classmethod
def fp8_graph_capturing(cls) -> bool:
"""Is CUDA graph capture under way?"""
return cls.FP8_GRAPH_CAPTURING or torch.cuda.is_current_stream_capturing()
@classmethod
def is_first_fp8_module(cls):
"""Returns `True` only the first time when called multiple
times from within the same `autocast` context.
"""
tmp = cls.IS_FIRST_FP8_MODULE
cls.IS_FIRST_FP8_MODULE = False
return tmp
@classmethod
def get_fp8_recipe(cls) -> Recipe:
"""Return the fp8 recipe"""
if cls.FP8_RECIPE is not None:
return cls.FP8_RECIPE
return get_default_fp8_recipe()
@classmethod
def get_fp8_group(cls) -> Union[dist_group_type, None]:
"""Return the fp8 group for scale/amax comm"""
return cls.FP8_DISTRIBUTED_GROUP
@classmethod
def get_autocast_state(cls) -> Tuple[bool, bool, Recipe, dist_group_type, bool, bool]:
"""FP8 autocast state getter"""
return (
cls.FP8_ENABLED,
cls.FP8_CALIBRATION,
cls.FP8_RECIPE,
cls.FP8_DISTRIBUTED_GROUP,
cls.IS_FIRST_FP8_MODULE,
cls.FP8_GRAPH_CAPTURING,
)
@classmethod
def set_autocast_state(
cls, fp8_state: Tuple[bool, bool, Recipe, dist_group_type, bool, bool]
) -> None:
"""FP8 autocast state setter"""
(
cls.FP8_ENABLED,
cls.FP8_CALIBRATION,
cls.FP8_RECIPE,
cls.FP8_DISTRIBUTED_GROUP,
cls.IS_FIRST_FP8_MODULE,
cls.FP8_GRAPH_CAPTURING,
) = fp8_state
@staticmethod
def reduce_tensor_across_group_op_max(tensor: torch.Tensor, group: dist_group_type) -> None:
"""Reduce tensor across given group."""
if torch.distributed.is_initialized():
torch.distributed.all_reduce(
tensor,
op=torch.distributed.ReduceOp.MAX,
group=group,
async_op=False,
)
@classmethod
def reduce_and_update_fp8_tensors(
cls,
forward: bool = True,
) -> None:
"""Delayed scaling only. Concatenate, reduce, and split amaxes in the global buffer."""
# global_amax_buffer should only be non-empty for fp8 delayed scaling
for buffer_key, amax_buffer in cls.global_amax_buffer.items():
# Check for forward or backward reduction.
fwd_update, autocast_key = cls.split_key_in_buffer(buffer_key)
if fwd_update != forward:
continue
if len(amax_buffer) == 0:
continue
# Retrieve autocast specific args and concat amaxes.
recipe, group = cls.autocast_arguments[autocast_key]
contiguous_amax = torch.cat(amax_buffer)
# Reduction.
if (
recipe.reduce_amax
and torch.distributed.is_initialized()
and torch.distributed.get_world_size(group=group) > 1
):
cls.reduce_tensor_across_group_op_max(contiguous_amax, group)
# Amax and scale update.
unfused_update = (
bool(int(os.getenv("NVTE_UNFUSED_FP8_UPDATE", "0")))
or callable(recipe.amax_compute_algo)
or callable(recipe.scaling_factor_compute_algo)
)
if not unfused_update:
tex.fused_amax_and_scale_update_after_reduction(
contiguous_amax,
cls.global_amax_history_buffer[buffer_key],
cls.global_scale_buffer[buffer_key],
recipe.amax_compute_algo,
get_fp8_te_dtype(recipe, forward),
recipe.margin,
)
else:
split_and_copy(contiguous_amax, amax_buffer, [x.numel() for x in amax_buffer])
for amax_history, scale in zip(
cls.global_amax_history_buffer[buffer_key],
cls.global_scale_buffer[buffer_key],
):
_amax_and_scale_update(
amax_history, scale, get_fp8_max(recipe, forward), recipe
)
@classmethod
def get_unique_autocast_key(
cls,
recipe: Optional[Recipe] = None,
group: Optional[dist_group_type] = None,
):
"""
For FP8, each autocast can be uniquely identified by the recipe and fp8 group.
Safely using `hash` as we never cross checkpoint boundaries.
"""
return f"{str(recipe)}:{hash(group)}"
@classmethod
def autocast_enter(
cls,
enabled: bool = False,
calibrating: bool = False,
fp8_recipe: Optional[Recipe] = None,
fp8_group: Optional[dist_group_type] = None,
_graph: bool = False,
) -> None:
"""Set state and tracking variables for entry into FP8 region."""
fp8_recipe = get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe
autocast_key = cls.get_unique_autocast_key(fp8_recipe, fp8_group)
cls.autocast_arguments[autocast_key] = (fp8_recipe, fp8_group)
cls.FP8_ENABLED = enabled
cls.FP8_CALIBRATION = calibrating
cls.FP8_RECIPE = fp8_recipe
cls.FP8_DISTRIBUTED_GROUP = fp8_group
cls.FP8_GRAPH_CAPTURING = _graph
if cls.AUTOCAST_DEPTH == 0:
cls.IS_FIRST_FP8_MODULE = True
cls.AUTOCAST_DEPTH += 1
if enabled:
fp8_available, reason_for_no_fp8 = cls.is_fp8_available()
assert fp8_available, reason_for_no_fp8
if isinstance(fp8_recipe, MXFP8BlockScaling):
mxfp8_available, reason_for_no_mxfp8 = cls.is_mxfp8_available()
assert mxfp8_available, reason_for_no_mxfp8
if isinstance(fp8_recipe, Float8BlockScaling):
fp8_block_available, reason_for_no_fp8_block = cls.is_fp8_block_scaling_available()
assert fp8_block_available, reason_for_no_fp8_block
if isinstance(fp8_recipe, NVFP4BlockScaling):
nvfp4_available, reason_for_no_nvfp4 = cls.is_nvfp4_available()
assert nvfp4_available, reason_for_no_nvfp4
@classmethod
def autocast_exit(cls, enabled: bool, _graph: bool) -> None:
"""Set state and tracking variables for exit from FP8 region."""
cls.AUTOCAST_DEPTH -= 1
# Reduce only the non-FP8 weight modules here.
# FP8 weight modules are reduced at the end of the optimizer
# step after the weight amax is populated.
if enabled and cls.AUTOCAST_DEPTH == 0 and not _graph and torch.is_grad_enabled():
# Delayed scaling only: for any other recipe (e.g. current scaling at any
# granularity) this is a no-op because cls.global_amax_buffer is empty.
cls.reduce_and_update_fp8_tensors(forward=True)
@classmethod
def copy_forward_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> None:
"""Copy the scaling factors and amaxes for recompute forward phase
to ensure both forward steps are numerically same.
"""
# delayed scaling only function, noop for any other recipe
if not fp8_meta["recipe"].delayed():
return
buffer_position_key = "global_fp8_buffer_pos_fwd_recompute"
to_copy = [
fp8_meta["scaling_fwd"].amax_history.clone(),
fp8_meta["scaling_fwd"].scale.clone(),
]
if buffer_position_key in fp8_meta:
cls.fp8_tensors_recompute_buffer[fp8_meta[buffer_position_key]].append(to_copy)
else:
if len(cls.fp8_tensors_recompute_buffer) == 0:
cls.fp8_tensors_recompute_buffer = [deque()]
else:
cls.fp8_tensors_recompute_buffer.append(deque())
cls.fp8_tensors_recompute_buffer[-1].append(to_copy)
fp8_meta[buffer_position_key] = len(cls.fp8_tensors_recompute_buffer) - 1
@classmethod
def get_old_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> None:
"""Switch to the copied scaling factors and amaxes from phase
1 forward for indentical numerical outputs.
"""
# delayed scaling only function, noop for any other recipe
if not fp8_meta["recipe"].delayed():
return
# Store updated amaxes and scales from phase 1 post forward.
fp8_meta["updated_amax_history_fwd"] = fp8_meta["scaling_fwd"].amax_history.clone()
fp8_meta["updated_scale_fwd"] = fp8_meta["scaling_fwd"].scale.clone()
# Retrieve stashed amaxes and scales from phase 1 pre forward.
buffer_position_key = "global_fp8_buffer_pos_fwd_recompute"
stashed_fp8_meta = cls.fp8_tensors_recompute_buffer[fp8_meta[buffer_position_key]].popleft()
# Replace amaxes and scales with stashed values for phase 2 forward
fp8_meta["scaling_fwd"].amax_history.copy_(stashed_fp8_meta[0])
fp8_meta["scaling_fwd"].scale.copy_(stashed_fp8_meta[1])
@staticmethod
def restore_fp8_meta_tensors(fp8_meta: Dict[str, Any]) -> None:
"""Restore latest scaling factors and amaxes after recompute forward run."""
# delayed scaling only function, noop for any other recipe
if not fp8_meta["recipe"].delayed():
return
fp8_meta["scaling_fwd"].amax_history.copy_(fp8_meta["updated_amax_history_fwd"])
fp8_meta["scaling_fwd"].scale.copy_(fp8_meta["updated_scale_fwd"])
@contextmanager
def fp8_model_init(
enabled: bool = True,
recipe: Optional[Recipe] = None,
preserve_high_precision_init_val: bool = False,
) -> None:
"""
.. warning::
fp8_model_init is deprecated and will be removed in a future release. Use
quantized_model_init(enabled=..., recipe=..., preserve_high_precision_init_val=...) instead.
"""
warnings.warn(
"fp8_model_init is deprecated and will be removed in a future release. "
"Use quantized_model_init("
"enabled=..., recipe=..., preserve_high_precision_init_val=...) instead.",
category=DeprecationWarning,
stacklevel=2,
)
# Call new implementation.
with quantized_model_init(
enabled=enabled,
recipe=recipe,
preserve_high_precision_init_val=preserve_high_precision_init_val,
):
yield
@contextmanager
def quantized_model_init(
enabled: bool = True,
recipe: Optional[Recipe] = None,
preserve_high_precision_init_val: bool = False,
) -> None:
"""
Context manager for initialization of quantized parameters.
Example usage:
.. code-block:: python
with quantized_model_init(enabled=True):
model = transformer_engine.pytorch.Linear(768, 768)
# Preserving high precision initial value to initialize master weight
with quantized_model_init(enabled=True, preserve_high_precision_init_val=True):
model = transformer_engine.pytorch.Linear(768, 768)
master_weight = model.weight.get_high_precision_init_val()
model.weight.clear_high_precision_init_val()
Parameters
----------
enabled: bool, default = `True`
when enabled, Transformer Engine modules created inside this `quantized_model_init`
region will hold only quantized copies of its parameters, as opposed to the default
behavior where both higher precision and quantized copies are present. Setting this
option to `True` may result in lower memory consumption and is especially
useful for scenarios like:
* full model training using optimizer with master weights, where the high
precision copies of weights are already present in the optimizer.
* inference, where only the quantized copies of the parameters are used.
* LoRA-like fine-tuning, where the main parameters of the model do not change.
recipe: transformer_engine.common.recipe.Recipe, default = `None`
Recipe used to create the parameters. If left to None, it uses the default recipe.
preserve_high_precision_init_val: bool, default = `False`
when enabled, store the high precision tensor used to initialize quantized parameters
in CPU memory, and add two function attributes named `get_high_precision_init_val()`
and `clear_high_precision_init_val()` to quantized parameters to get/clear this high
precision tensor. The purpose is that users can use this high-precision copy
to initialize master weights, avoiding the loss of precision that can occur when
using quantized parameters directly. Note that after the master weights are initialized,
users should call `clear_high_precision_init_val()` to release this CPU memory.
This functionality is *EXPERIMENTAL*.
"""
_fp8_parameters = FP8GlobalStateManager.FP8_PARAMETERS
_fp8_recipe = FP8GlobalStateManager.FP8_RECIPE
_high_precision_init_val = FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL
FP8GlobalStateManager.FP8_PARAMETERS = enabled
FP8GlobalStateManager.FP8_RECIPE = get_default_fp8_recipe() if recipe is None else recipe
FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL = preserve_high_precision_init_val
try:
yield
finally:
FP8GlobalStateManager.FP8_PARAMETERS = _fp8_parameters
FP8GlobalStateManager.FP8_RECIPE = _fp8_recipe
FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL = _high_precision_init_val
@contextmanager
def fp8_autocast(
enabled: bool = True,
calibrating: bool = False,
fp8_recipe: Optional[Recipe] = None,
fp8_group: Optional[dist_group_type] = None,
_graph: bool = False,
) -> None:
"""
.. warning::
fp8_autocast is deprecated and will be removed in a future release.
Use autocast(enabled=..., calibrating=..., recipe=..., group=..., _graph=...) instead.
"""
warnings.warn(
"fp8_autocast is deprecated and will be removed in a future release. "
"Use autocast(enabled=..., calibrating=..., recipe=..., group=..., _graph=...) instead.",
category=DeprecationWarning,
stacklevel=2,
)
# Call new implementation.
with autocast(
enabled=enabled,
calibrating=calibrating,
recipe=fp8_recipe,
amax_reduction_group=fp8_group,
_graph=_graph,
):
yield
@contextmanager
def autocast(
enabled: bool = True,
calibrating: bool = False,
recipe: Optional["Recipe"] = None,
amax_reduction_group: Optional["dist_group_type"] = None,
_graph: bool = False,
) -> None:
"""
Context manager for quantization schemes like FP8 or FP4.
.. code-block:: python
with autocast(enabled=True):
out = model(inp)
.. note::
Support for FP8 in the Linear layer of Transformer Engine is currently limited to tensors
with shapes where both dimensions are divisible by 16. In terms of the input to the full
Transformer network, this typically requires padding the sequence length to be a multiple of 16.
.. note::
When :attr:`recipe.reduce_amax==True`, any module must not be invoked more than once
inside a single `autocast` region. This is unsupported behavior because the amax
reduction is handled during the exit of the `autocast` context. Calling the same
module more than once inside an `autocast` region overwrites the amax tensors
before reduction can occur.
Parameters
----------
enabled: bool, default = `True`
whether or not to enable low precision quantization (FP8/FP4).
calibrating: bool, default = `False`
calibration mode allows collecting statistics such as amax and scale
data of quantized tensors even when executing without quantization enabled.
This is useful for saving an inference ready checkpoint while training
using a higher precision.
recipe: recipe.Recipe, default = `None`
recipe used for low precision quantization.
amax_reduction_group: torch._C._distributed_c10d.ProcessGroup, default = `None`
distributed group over which amaxes for the quantized tensors
are reduced at the end of each training step.
"""
if enabled:
check_recipe_support(recipe)
# Save current state so we always restore it on exit.
fp8_state = FP8GlobalStateManager.get_autocast_state()
FP8GlobalStateManager.autocast_enter(
enabled=enabled,
calibrating=calibrating,
fp8_recipe=recipe,
fp8_group=amax_reduction_group,
_graph=_graph,
)
try:
yield
finally:
FP8GlobalStateManager.set_autocast_state(fp8_state)
FP8GlobalStateManager.autocast_exit(enabled, _graph=_graph)
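# Illustrative training-loop usage of `autocast` (a minimal sketch; `model` and
# `inp` are placeholders, not part of this module):
#
#     from transformer_engine.common.recipe import DelayedScaling, Format
#
#     fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16)
#     with autocast(enabled=True, recipe=fp8_recipe, amax_reduction_group=None):
#         out = model(inp)
#     out.sum().backward()  # forward amax reduction/scale update happens on context exit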
def _update_amax_history(amax_history: torch.Tensor) -> torch.Tensor:
"""Update amax history and set next amax to zero."""
if amax_history.shape[0] > 1:
new_amax_history = torch.roll(amax_history, -1, 0)
amax_history.copy_(new_amax_history)
amax_history[0].fill_(0.0)
return amax_history
@torch.jit.script
def _default_get_amax_and_update_history(
amax_history: torch.Tensor,
amax_compute_algo: str,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Default function to obtain amax from history."""
if amax_compute_algo == "max":
amax = torch.max(amax_history, dim=0).values
else: # amax_compute_algo == "most_recent"
amax = amax_history[0].clone()
amax_history = _update_amax_history(amax_history)
return amax_history, amax
@jit_fuser
def _default_sf_compute(
amax: torch.Tensor,
scale: torch.Tensor,
fp8_max: float,
margin: int,
_fp32_max: float = torch.finfo(torch.float32).max,  # finfo not available inside the JIT fuser
) -> torch.Tensor:
"""Default function to convert amax to scaling factor.
Computing the scaling factor requires consideration of the following scenarios:
1. amax == 0:
No action is possible, set scale to the previous scale (or 1).
2. 0 < amax < tiny_amax:
The amax is so tiny that the scale overflows to infinity in FP32.
Set scale = FP32_max.
3. tiny_amax <= amax < FP32_max:
Set scale = FP8_max (or scaled_max) / amax
4. When amax == inf or amax == nan:
No action is possible, set scale to the previous scale (or 1).
"""
sf = (fp8_max / amax) / (2**margin)
sf = torch.where(amax > 0.0, sf, scale)
sf = torch.where(torch.isfinite(amax), sf, scale)
sf = torch.where(torch.isinf(sf), torch.full_like(sf, _fp32_max), sf)
scale.copy_(sf)
return scale
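# Worked example for _default_sf_compute with margin=0 and fp8_max=448.0 (E4M3):
#   amax = 0.5          -> scale = 448.0 / 0.5 = 896.0
#   amax = 0.0          -> previous scale is kept (case 1 in the docstring above)
#   amax = inf or nan   -> previous scale is kept (case 4)
#   0 < amax < ~1.3e-36 -> 448.0 / amax overflows FP32, so scale is clamped to FP32 max (case 2)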
def _compute_amax_and_update_history(
amax_history: torch.Tensor,
amax_compute_algo: Union[Callable, str],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Obtain the amax from the history."""
if callable(amax_compute_algo):
amax = amax_compute_algo(amax_history)
amax_history = _update_amax_history(amax_history)
return amax_history, amax
return _default_get_amax_and_update_history(
amax_history,
amax_compute_algo,
)
def _compute_scaling_factor(
amax: torch.Tensor,
scale: torch.Tensor,
fp8_max: float,
recipe: DelayedScaling,
) -> torch.Tensor:
"""Convert amax to scaling factor."""
if recipe.scaling_factor_compute_algo is None:
return _default_sf_compute(
amax,
scale,
fp8_max,
recipe.margin,
)
return recipe.scaling_factor_compute_algo(amax, scale, fp8_max, recipe)
def _amax_and_scale_update(
amax_history: torch.Tensor,
scale: torch.Tensor,
fp8_max: float,
recipe: DelayedScaling,
) -> None:
"""Updates FP8 meta tensors."""
new_amax_history, amax = _compute_amax_and_update_history(
amax_history,
recipe.amax_compute_algo,
)
new_scale = _compute_scaling_factor(amax, scale, fp8_max, recipe)
scale.copy_(new_scale)
amax_history.copy_(new_amax_history)
def split_and_copy(
buffer: torch.Tensor,
outputs: List[torch.Tensor],
chunk_sizes: List[int],
) -> None:
"""Split `buffer` by `chunk_sizes` and copy into `outputs`."""
splits = buffer.split(chunk_sizes)
torch._foreach_copy_(outputs, splits)
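# Illustrative use of split_and_copy (a minimal sketch):
#
#     buf = torch.arange(5.0, device="cuda")
#     outs = [torch.empty(2, device="cuda"), torch.empty(3, device="cuda")]
#     split_and_copy(buf, outs, [2, 3])
#     # outs[0] == tensor([0., 1.]), outs[1] == tensor([2., 3., 4.])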
class RecipeState(abc.ABC):
"""Configuration and state for a quantization recipe.
This is a builder class for quantizers, which are in turn builder
classes for quantized tensors.
This class may pack together the state for multiple quantizers,
which is helpful for applying fused kernels with less overhead.
"""
@staticmethod
def create(
recipe: Recipe,
*,
mode: str,
num_quantizers: int = 1,
device: Optional[torch.device] = None,
) -> RecipeState:
"""Factory method to create the state for a quantization recipe
Parameters
----------
recipe: Recipe
Quantization recipe.
mode: {"forward", "backward"}
Training stage where quantization will be performed.
num_quantizers: int, default = 1
Number of quantizers to create state for.
device: torch.device, default = default CUDA device
Device for quantized tensors.
Returns
-------
RecipeState:
Quantization recipe state.
"""
cls = None
if recipe.delayed():
cls = DelayedScalingRecipeState
elif recipe.mxfp8():
cls = MXFP8BlockScalingRecipeState
elif recipe.float8_current_scaling():
cls = Float8CurrentScalingRecipeState
elif recipe.float8_block_scaling():
cls = Float8BlockScalingRecipeState
elif recipe.nvfp4():
cls = NVFP4BlockScalingRecipeState
elif recipe.custom():
cls = CustomRecipeState
else:
raise ValueError(f"{recipe.__class__.__name__} is not supported")
return cls(
recipe,
mode=mode,
num_quantizers=num_quantizers,
device=device,
)
@abc.abstractmethod
def make_quantizers(self) -> list:
"""Convert recipe state to quantizers.
Quantizers are builder classes for quantized tensors. They are
typically used to convert a high-precision tensor (e.g. in
FP32 or BF16) into a quantized tensor (e.g. in FP8).
"""
class DelayedScalingRecipeState(RecipeState):
"""State for FP8 quantization with per-tensor delayed scaling.
Delayed scaling recipe requires a scaling factor (applied when
casting to FP8) and a history of max-abs values ("amax") from
recent FP8 casts for updating the scaling factor. The scale update
is handled externally by `FP8GlobalStateManager`.
"""
recipe: DelayedScaling
mode: str
dtype: tex.DType
scale: torch.Tensor
amax_history: torch.Tensor
def __init__(
self,
recipe: DelayedScaling,
*,
mode: str,
num_quantizers: int = 1,
device: Optional[torch.device] = None,
) -> None:
self.recipe = recipe
self.mode = mode
self.num_quantizers = num_quantizers
self.dtype = get_fp8_te_dtype(recipe, mode == "forward")
# Allocate buffers
if device is None:
device = torch.device("cuda")
self.scale = torch.ones(num_quantizers, dtype=torch.float32, device=device)
self.amax_history = torch.zeros(
recipe.amax_history_len,
num_quantizers,
dtype=torch.float32,
device=device,
)
def make_quantizers(self) -> list:
# TODO(ksivamani): Find better design for this, adding here to avoid circular import.
from .tensor.float8_tensor import Float8Quantizer
return [
Float8Quantizer(self.scale[i], self.amax_history[0][i].reshape((1,)), self.dtype)
for i in range(self.num_quantizers)
]
class Float8CurrentScalingRecipeState(RecipeState):
"""Configuration for Per-tensor current scaling quantization.
Per-tensor current quantization does not require state.
"""
recipe: Float8CurrentScaling
mode: str
dtype: tex.DType
device: torch.device
def __init__(
self,
recipe: Float8CurrentScaling,
*,
mode: str,
num_quantizers: int = 1,
device: Optional[torch.device] = None,
) -> None:
self.recipe = recipe
self.mode = mode
self.num_quantizers = num_quantizers
self.dtype = get_fp8_te_dtype(recipe, mode == "forward")
# Allocate buffers
if device is None:
device = torch.device("cuda")
self.device = device
def make_quantizers(self) -> list:
from .tensor.float8_tensor import Float8CurrentScalingQuantizer
return [
Float8CurrentScalingQuantizer(
self.dtype, device=self.device, force_pow_2_scales=self.recipe.use_power_2_scales
)
for i in range(self.num_quantizers)
]
class MXFP8BlockScalingRecipeState(RecipeState):
"""Configuration for MXFP8 quantization.
MXFP8 quantization does not require state.
"""
recipe: MXFP8BlockScaling
mode: str
dtype: tex.DType
def __init__(
self,
recipe: MXFP8BlockScaling,
*,
mode: str,
num_quantizers: int = 1,
device: Optional[torch.device] = None,
) -> None:
self.recipe = recipe
self.mode = mode
self.num_quantizers = num_quantizers
self.dtype = get_fp8_te_dtype(recipe, mode == "forward")
# Allocate buffers
if device is None:
device = torch.device("cuda")
def make_quantizers(self) -> list:
# TODO(ksivamani): Find better design for this, adding here to avoid circular import.
from .tensor.mxfp8_tensor import MXFP8Quantizer
return [MXFP8Quantizer(self.dtype) for i in range(self.num_quantizers)]
class Float8BlockScalingRecipeState(RecipeState):
"""Configuration for Float8BlockScaling quantization.
Float8BlockScaling quantization does not require state,
but different quantizers use different modes.
"""
recipe: Float8BlockScaling
mode: str
qx_dtype: tex.DType
qw_dtype: tex.DType
qgrad_dtype: tex.DType
def __init__(
self,
recipe: Float8BlockScaling,
*,
mode: str,
num_quantizers: int = 1,
device: Optional[torch.device] = None,
) -> None:
self.recipe = recipe
self.mode = mode
self.num_quantizers = num_quantizers
self.qx_dtype = get_fp8_te_dtype(recipe, True)
self.qw_dtype = get_fp8_te_dtype(recipe, True)
self.qgrad_dtype = get_fp8_te_dtype(recipe, False)
# Allocate buffers
if device is None:
device = torch.device("cuda")
self.device = device
def make_quantizers(self) -> list:
# TODO(ksivamani): Find better design for this, adding here to avoid circular import.
from .tensor.float8_blockwise_tensor import Float8BlockQuantizer
if self.mode == "forward":
# The index convention (coming from base.py set_meta_tensor)
# is somewhat awkward, and doesn't play nicely with QuantizeOp,
# which is not associated with a GEMM.
assert self.num_quantizers % 3 == 0 # x, w, output per gemm
return list(
itertools.chain.from_iterable(
[
[
Float8BlockQuantizer(
fp8_dtype=self.qx_dtype,
rowwise=True,
columnwise=True,
amax_epsilon=self.recipe.fp8_quant_fwd_inp.amax_epsilon,
force_pow_2_scales=self.recipe.fp8_quant_fwd_inp.power_2_scale,
block_scaling_dim=self.recipe.x_block_scaling_dim,
),
Float8BlockQuantizer(
fp8_dtype=self.qw_dtype,
rowwise=True,
columnwise=True,
amax_epsilon=self.recipe.fp8_quant_fwd_weight.amax_epsilon,
force_pow_2_scales=self.recipe.fp8_quant_fwd_weight.power_2_scale,
block_scaling_dim=self.recipe.w_block_scaling_dim,
),
Float8BlockQuantizer(
fp8_dtype=self.qx_dtype,
rowwise=True,
columnwise=True,
amax_epsilon=self.recipe.fp8_quant_fwd_inp.amax_epsilon,
force_pow_2_scales=self.recipe.fp8_quant_fwd_inp.power_2_scale,
block_scaling_dim=self.recipe.x_block_scaling_dim,
),
]
for _ in range(self.num_quantizers // 3)
]
)
)
assert self.mode == "backward", f"Unexpected mode {self.mode}"
assert self.num_quantizers % 2 == 0 # grad_output and grad_input per gemm
return list(
itertools.chain.from_iterable(
[
[
Float8BlockQuantizer(
fp8_dtype=self.qgrad_dtype,
rowwise=True,
columnwise=True,
amax_epsilon=self.recipe.fp8_quant_bwd_grad.amax_epsilon,
force_pow_2_scales=self.recipe.fp8_quant_bwd_grad.power_2_scale,
block_scaling_dim=self.recipe.grad_block_scaling_dim,
),
Float8BlockQuantizer(
fp8_dtype=self.qgrad_dtype,
rowwise=True,
columnwise=True,
amax_epsilon=self.recipe.fp8_quant_bwd_grad.amax_epsilon,
force_pow_2_scales=self.recipe.fp8_quant_bwd_grad.power_2_scale,
block_scaling_dim=self.recipe.grad_block_scaling_dim,
),
]
for _ in range(self.num_quantizers // 2)
]
)
)
class NVFP4BlockScalingRecipeState(RecipeState):
"""Configuration for NVFP4 quantization.
NVFP4 quantization does not require state.
"""
recipe: NVFP4BlockScaling
mode: str
dtype: tex.DType
def __init__(
self,
recipe: NVFP4BlockScaling,
*,
mode: str,
num_quantizers: int = 1,
device: Optional[torch.device] = None,
) -> None:
self.recipe = recipe
self.mode = mode
self.num_quantizers = num_quantizers
self.dtype = get_fp4_te_dtype(recipe)
# Allocate buffers
if device is None:
device = torch.device("cuda")
def make_quantizers(self) -> list:
from .tensor.nvfp4_tensor import NVFP4Quantizer
# The index convention (coming from base.py set_meta_tensor)
# is somewhat awkward. It assumes forward quantizers are
# ordered [input, weight, output, ...] and backward quantizers
# are ordered [grad_output, grad_input, ...]. This doesn't
# play nicely with fusible ops: Linear op doesn't own output
# or grad input quantizers, Quantize op only owns input and
# grad output quantizers.
if self.mode == "forward":
def _make_quantizer(idx: int) -> NVFP4Quantizer:
qparams = (
self.recipe.fp4_quant_fwd_weight
if idx % 3 == 1
else self.recipe.fp4_quant_fwd_inp
)
return NVFP4Quantizer(
fp4_dtype=self.dtype,
rowwise=True,
columnwise=True,
with_rht=qparams.random_hadamard_transform,
with_post_rht_amax=qparams.random_hadamard_transform,
with_2d_quantization=qparams.fp4_2d_quantization,
stochastic_rounding=qparams.stochastic_rounding,
)
return [_make_quantizer(idx) for idx in range(self.num_quantizers)]
if self.mode == "backward":
return [
NVFP4Quantizer(
fp4_dtype=self.dtype,
rowwise=True,
columnwise=True,
with_rht=self.recipe.fp4_quant_bwd_grad.random_hadamard_transform,
with_post_rht_amax=self.recipe.fp4_quant_bwd_grad.random_hadamard_transform,
with_2d_quantization=self.recipe.fp4_quant_bwd_grad.fp4_2d_quantization,
stochastic_rounding=self.recipe.fp4_quant_bwd_grad.stochastic_rounding,
)
for _ in range(self.num_quantizers)
]
raise RuntimeError(f"Unexpected recipe mode ({self.mode})")
class CustomRecipeState(RecipeState):
"""State for CustomRecipe: produce quantizers per tensor."""
recipe: CustomRecipe
mode: str
num_quantizers: int
device: Optional[torch.device]
def __init__(
self,
recipe: CustomRecipe,
*,
mode: str,
num_quantizers: int = 1,
device: Optional[torch.device] = None,
) -> None:
self.recipe = recipe
self.mode = mode
self.num_quantizers = num_quantizers
if device is None:
device = torch.device("cuda")
self.device = device
if getattr(recipe, "qfactory", None) is None:
raise ValueError("CustomRecipe requires `qfactory`.")
def make_quantizers(self) -> list:
qfactory = self.recipe.qfactory
out = []
# TODO(negvet): make_quantizers() should take roles from the operation
# Hardcode linear-specific roles for now
roles: List[str]
if self.mode == "forward":
roles = [
("linear_input", "linear_weight", "linear_output")[i % 3]
for i in range(self.num_quantizers)
]
elif self.mode == "backward":
roles = [
("linear_grad_output", "linear_grad_input")[i % 2]
for i in range(self.num_quantizers)
]
else:
roles = ["unknown"] * self.num_quantizers
for i in range(self.num_quantizers):
# Get quantizer from the user defined factory
quantizer = qfactory(roles[i])
out.append(quantizer)
return out
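# Minimal sketch of a user-supplied `qfactory` for CustomRecipe (illustrative;
# assumes CustomRecipe accepts the factory as its `qfactory` argument and that
# returning the same quantizer type for every role is acceptable):
#
#     from transformer_engine.pytorch.tensor.float8_tensor import Float8CurrentScalingQuantizer
#
#     def my_qfactory(role: str):
#         # `role` is one of the hard-coded strings above, e.g. "linear_input"
#         return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device=torch.device("cuda"))
#
#     recipe = CustomRecipe(qfactory=my_qfactory)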
@@ -45,7 +45,7 @@ if bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))) or os.path.isdir(build_tools_
from build_tools.build_ext import get_build_ext
from build_tools.utils import copy_common_headers
from build_tools.utils import copy_common_headers, min_python_version_str
from build_tools.te_version import te_version
from build_tools.pytorch import (
setup_pytorch_extension,
@@ -152,6 +152,7 @@ if __name__ == "__main__":
description="Transformer acceleration library - Torch Lib",
ext_modules=ext_modules,
cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": CachedWheelsCommand},
python_requires=f">={min_python_version_str()}",
install_requires=install_requirements(),
tests_require=test_requirements(),
)
@@ -6,12 +6,42 @@
import torch
from .quantized_tensor import QuantizedTensor, Quantizer
from .quantized_tensor import (
QuantizedTensorStorage,
QuantizedTensor,
Quantizer,
prepare_for_saving,
restore_from_saved,
)
from .storage.float8_tensor_storage import Float8TensorStorage
from .storage.mxfp8_tensor_storage import MXFP8TensorStorage
from .storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage
from .storage.nvfp4_tensor_storage import NVFP4TensorStorage
from .float8_tensor import Float8Tensor, Float8Quantizer, Float8CurrentScalingQuantizer
from .mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer
from .float8_blockwise_tensor import Float8BlockwiseQTensor, Float8BlockQuantizer
from .nvfp4_tensor import NVFP4Tensor, NVFP4Quantizer
from .utils import cast_master_weights_to_fp8, replace_raw_data
__all__ = [
"QuantizedTensor",
"Quantizer",
"Float8Quantizer",
"Float8CurrentScalingQuantizer",
"MXFP8Quantizer",
"Float8BlockQuantizer",
"NVFP4Quantizer",
"QuantizedTensorStorage",
"Float8TensorStorage",
"MXFP8TensorStorage",
"Float8BlockwiseQTensorStorage",
"NVFP4TensorStorage",
"QuantizedTensor",
"Float8Tensor",
"MXFP8Tensor",
"Float8BlockwiseQTensor",
"NVFP4Tensor",
"prepare_for_saving",
"restore_from_saved",
]
@@ -48,21 +78,16 @@ def get_all_tensor_types():
"""
Get all tensor-like types that can be used in TE.
"""
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8TensorBase
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8TensorBase
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
Float8BlockwiseQTensor,
Float8BlockwiseQTensorBase,
)
all_tensor_types = [
torch.Tensor,
torch.nn.Parameter,
Float8Tensor,
Float8TensorBase,
Float8TensorStorage,
MXFP8Tensor,
MXFP8TensorBase,
MXFP8TensorStorage,
Float8BlockwiseQTensor,
Float8BlockwiseQTensorBase,
Float8BlockwiseQTensorStorage,
NVFP4Tensor,
NVFP4TensorStorage,
]
return all_tensor_types
@@ -14,8 +14,12 @@ from transformer_engine_torch import DType as TE_DType
from transformer_engine_torch import Float8BlockScaleTensorFormat
from transformer_engine.common.recipe import Float8BlockScaling, Recipe
from ._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc
from .storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage
from .quantized_tensor import (
QuantizedTensor,
Quantizer,
_IdentityFunc,
)
from ..utils import devices_match, round_up_to_nearest_multiple
from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
@@ -104,6 +108,10 @@ class Float8BlockQuantizer(Quantizer):
dst._fp8_dtype = self.dtype
return dst
def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor:
"""Quantize tensor implementation"""
return tex.quantize(tensor, self)
def get_scale_shape(self, shape: Iterable[int], columnwise: bool) -> Tuple[int, int]:
"""Calculate the shape of the scaling tensor for blockwise quantization.
@@ -273,7 +281,7 @@ class Float8BlockQuantizer(Quantizer):
return Float8BlockScaling
class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor):
class Float8BlockwiseQTensor(Float8BlockwiseQTensorStorage, QuantizedTensor):
"""Tensor class with FP8 data quantized via NxN blocks or 1xN blocks.
The tensor presents as having a standard, higher-precision dtype,
@@ -298,7 +306,7 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor):
holds configuration about quantization and dequantization modes.
"""
# NOTE: We reorder the *args so that we can instantiate a Float8BlockwiseQTensorBase with positional args,
# NOTE: We reorder the *args so that we can instantiate a Float8BlockwiseQTensorStorage with positional args,
# which significantly reduces the Pybind11 overhead when calling the constructor from C++.
def __new__(
cls,
@@ -337,15 +345,6 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor):
f" data_format={self._data_format}"
)
def _get_quantizer(self) -> Quantizer:
"""Get builder for quantized tensor
Quantizer can be used for in-place operations.
"""
assert self._quantizer is not None
return self._quantizer
def quantize_(
self,
tensor: torch.Tensor,
@@ -364,8 +363,7 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor):
"""
if isinstance(tensor, QuantizedTensor):
return self.quantize_(tensor.dequantize())
self._get_quantizer().update_quantized(tensor, self, noop_flag=noop_flag)
return self
return super().quantize_(tensor, noop_flag=noop_flag)
def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
"""
@@ -408,6 +406,21 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor):
# pylint: disable=missing-function-docstring
return _ReshapeFunc.apply(self, shape)
def untyped_storage(self) -> torch.UntypedStorage:
"""Return the underlying UntypedStorage of the FP8 data.
Note that FP8 block-scaled tensor may involve multiple
buffers: row-wise FP8 data, row-wise scales, column-wise FP8
data, column-wise scales. The UntypedStorage of the row-wise
FP8 data is returned if it exists, and otherwise the
UntypedStorage of the column-wise FP8 data.
"""
data = self._rowwise_data if self._rowwise_data is not None else self._columnwise_data
if data is not None:
return data.untyped_storage()
return torch.UntypedStorage(0, device=self.device)
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None):
@@ -432,6 +445,19 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor):
)
return Float8BlockwiseQTensor.make_like(tensor)
# record stream op
if func == torch.ops.aten.record_stream.default:
qt, stream = args
for t in (
qt._rowwise_data,
qt._columnwise_data,
qt._rowwise_scale_inv,
qt._columnwise_scale_inv,
):
if t is not None and t.is_cuda:
t.record_stream(stream)
return None
# Default case
return super().__torch_dispatch__(func, types, args, kwargs)
@@ -13,8 +13,12 @@ from transformer_engine_torch import DType as TE_DType
from transformer_engine.common.recipe import DelayedScaling, Float8CurrentScaling, Recipe
from ..utils import canonicalize_process_group, devices_match
from ._internal.float8_tensor_base import Float8TensorBase, _FromFloat8Func
from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc
from .storage.float8_tensor_storage import Float8TensorStorage, _FromFloat8Func
from .quantized_tensor import (
QuantizedTensor,
Quantizer,
_IdentityFunc,
)
from ..constants import dist_group_type
from transformer_engine.pytorch.fp8 import int8_simulation_fp8_tensorwise
@@ -90,6 +94,10 @@ class Float8Quantizer(Quantizer):
return dst
def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor:
"""Quantize tensor implementation"""
return tex.quantize(tensor, self)
def make_empty(
self,
shape: Iterable[int],
@@ -148,7 +156,7 @@ class Float8Quantizer(Quantizer):
torch.float8_e5m2fnuz,
]
if internal:
return Float8TensorBase(
return Float8TensorStorage(
data=data,
fp8_scale_inv=1 / self.scale,
fp8_dtype=self.dtype,
@@ -216,6 +224,8 @@ class Float8CurrentScalingQuantizer(Quantizer):
amax: torch.Tensor
"""FP8 datatype"""
dtype: TE_DType
"""amax update options"""
use_existing_amax: bool
"""amax reduction options"""
with_amax_reduction: bool
amax_reduction_group: Optional[dist_group_type]
@@ -230,6 +240,7 @@ class Float8CurrentScalingQuantizer(Quantizer):
*,
rowwise: bool = True,
columnwise: bool = True,
use_existing_amax: bool = False,
with_amax_reduction: bool = False,
amax_reduction_group: Optional[dist_group_type] = None,
force_pow_2_scales: bool = False,
@@ -239,6 +250,7 @@ class Float8CurrentScalingQuantizer(Quantizer):
self.scale = torch.empty(1, dtype=torch.float32, device=device)
self.amax = torch.empty(1, dtype=torch.float32, device=device)
self.dtype = tex.DType.kInt8 if int8_simulation_fp8_tensorwise else fp8_dtype
self.use_existing_amax = use_existing_amax
self.with_amax_reduction = with_amax_reduction
self.amax_reduction_group = amax_reduction_group
self.force_pow_2_scales = force_pow_2_scales
@@ -268,6 +280,10 @@ class Float8CurrentScalingQuantizer(Quantizer):
return dst
def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor:
"""Quantize tensor implementation"""
return tex.quantize(tensor, self)
def make_empty(
self,
shape: Iterable[int],
@@ -330,7 +346,7 @@ class Float8CurrentScalingQuantizer(Quantizer):
torch.float8_e5m2fnuz,
]
if internal:
return Float8TensorBase(
return Float8TensorStorage(
data=data,
fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=data.device),
fp8_dtype=self.dtype,
@@ -385,7 +401,7 @@ class Float8CurrentScalingQuantizer(Quantizer):
return True
class Float8Tensor(Float8TensorBase, QuantizedTensor):
class Float8Tensor(Float8TensorStorage, QuantizedTensor):
"""Experimental tensor class with FP8 data
The tensor presents as having a standard, higher-precision dtype,
@@ -440,19 +456,6 @@ class Float8Tensor(Float8TensorBase, QuantizedTensor):
return _FromFloat8Func.apply(self, dtype)
return _FromFloat8Func.forward(None, self, dtype)
def _get_quantizer(self) -> Quantizer:
"""Get builder for quantized tensor
Quantizer can be used for in-place operations.
"""
if self._quantizer is not None:
return self._quantizer
# Now the quantizer for Float8Tensor can be not just Float8Quantizer (delayed scaling)
raise ValueError(
"Float8Tensor's quantizer is None, cannot get a quantizer from Float8Tensor variable"
)
def quantize_(
self,
tensor: torch.Tensor,
@@ -471,8 +474,7 @@ class Float8Tensor(Float8TensorBase, QuantizedTensor):
"""
if isinstance(tensor, QuantizedTensor):
return self.quantize_(tensor.dequantize(), noop_flag=noop_flag)
self._get_quantizer().update_quantized(tensor, self, noop_flag=noop_flag)
return self
return super().quantize_(tensor, noop_flag=noop_flag)
def detach(self) -> Float8Tensor:
# pylint: disable=missing-function-docstring
@@ -2,7 +2,7 @@
#
# See LICENSE for license information.
"""Tensor class with FP8 data"""
"""Tensor class with MXFP8 data"""
from __future__ import annotations
from collections.abc import Iterable
import math
@@ -16,8 +16,12 @@ from transformer_engine.common.recipe import MXFP8BlockScaling, Recipe
from ..constants import MXFP8_BLOCK_SCALING_SIZE
from ..utils import devices_match, round_up_to_nearest_multiple
from ._internal.mxfp8_tensor_base import MXFP8TensorBase, _FromMXFP8Func
from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc
from .storage.mxfp8_tensor_storage import MXFP8TensorStorage, _FromMXFP8Func
from .quantized_tensor import (
QuantizedTensor,
Quantizer,
_IdentityFunc,
)
aten = torch.ops.aten
@@ -67,6 +71,10 @@ class MXFP8Quantizer(Quantizer):
return dst
def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor:
"""Quantize tensor implementation"""
return tex.quantize(tensor, self)
def is_quantizable(self, inp: torch.Tensor) -> bool:
"""Returns whether or not given inp can be quantized"""
if inp.ndim < 2:
@@ -161,14 +169,14 @@ class MXFP8Quantizer(Quantizer):
data, scale_inv = torch.ops.tex.mxfp8_quantize(tensor)
return self.create_tensor_from_data(data, scale_inv, fake_dtype=torch.float32)
def onnx_dequantize(self, tensor: Union[MXFP8TensorBase, MXFP8Tensor]) -> torch.Tensor:
def onnx_dequantize(self, tensor: Union[MXFP8TensorStorage, MXFP8Tensor]) -> torch.Tensor:
return torch.ops.tex.mxfp8_dequantize(tensor._rowwise_data, tensor._rowwise_scale_inv)
def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
return MXFP8BlockScaling
class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor):
class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
"""Experimental tensor class with FP8 data
The tensor presents as having a standard, higher-precision dtype,
@@ -186,14 +194,13 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor):
Reciprocal of the scaling factor applied when
casting to FP8, i.e. the scaling factor that must
be applied when casting from FP8 to higher
precision. Can be inferred from fp8_meta if
provided.
precision.
dtype: torch.dtype, default = torch.float32
Nominal tensor datatype.
"""
# NOTE: We reorder the *args so that we can instantiate a MXFP8TensorBase with positional args,
# NOTE: We reorder the *args so that we can instantiate a MXFP8TensorStorage with positional args,
# which significantly reduces the Pybind11 overhead when calling the constructor from C++.
def __new__(
cls,
@@ -237,17 +244,9 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor):
return _FromMXFP8Func.apply(self, dtype)
return _FromMXFP8Func.forward(None, self, dtype)
def _get_quantizer(self) -> Quantizer:
"""Get builder for quantized tensor
Quantizer can be used for in-place operations.
"""
if self._quantizer is not None:
return self._quantizer
return MXFP8Quantizer(
fp8_dtype=self._fp8_dtype,
)
def _build_default_quantizer(self) -> Optional[Quantizer]:
"""Build default quantizer for the tensor"""
return MXFP8Quantizer(fp8_dtype=self._fp8_dtype)
def quantize_(
self,
@@ -267,8 +266,7 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor):
"""
if isinstance(tensor, QuantizedTensor):
return self.quantize_(tensor.dequantize())
self._get_quantizer().update_quantized(tensor, self, noop_flag=noop_flag)
return self
return super().quantize_(tensor, noop_flag=noop_flag)
def detach(self) -> MXFP8Tensor:
# pylint: disable=missing-function-docstring
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Tensor class with NVFP4 data"""
from __future__ import annotations
from collections.abc import Iterable
import math
from typing import Optional, Tuple, Union
import functools
import torch
import transformer_engine_torch as tex
from transformer_engine_torch import DType as TE_DType
from transformer_engine.common.recipe import NVFP4BlockScaling, Recipe
from ..constants import NVFP4_BLOCK_SCALING_SIZE, dist_group_type
from ..utils import (
canonicalize_process_group,
devices_match,
round_up_to_nearest_multiple,
)
from .storage.nvfp4_tensor_storage import NVFP4TensorStorage, _FromNVFP4Func
from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc
aten = torch.ops.aten
def get_no_random_sign_vector() -> torch.Tensor:
"""Non-random sign vector for Hadamard transform."""
return torch.tensor([1], dtype=torch.float32)
def get_sign_from_vector(vector: torch.Tensor) -> int:
"""Convert sign vector to bitmask.
Used for random Hadamard transform.
"""
mask = 0
for i, v in enumerate(vector):
mask |= (v == -1) << i
return mask
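# Illustrative sketch (not part of the original source): for a hypothetical sign
# vector [1, -1, 1, -1], bits 1 and 3 are set, so the returned mask is 0b1010 (10).
#
#   signs = torch.tensor([1, -1, 1, -1], dtype=torch.float32)
#   assert int(get_sign_from_vector(signs)) == 0b1010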
def get_wgrad_sign_vector() -> torch.Tensor:
"""Hard-coded random signs for Hadamard transform.
https://xkcd.com/221/
"""
return torch.tensor(
[1, 1, 1, -1, 1, -1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1],
dtype=torch.float32,
)
def get_hadamard_matrix(hadamard_dimension: int) -> torch.Tensor:
"""Construct a 16x16 Hadamard matrix."""
assert hadamard_dimension == 16, "Only hadamard dimension 16 is supported."
hadamard_scale = 1 / math.sqrt(hadamard_dimension)
return (
torch.tensor(
[
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1],
[1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1],
[1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1],
[1, 1, 1, 1, -1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1],
[1, -1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1],
[1, 1, -1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1],
[1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1],
[1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1],
[1, -1, 1, -1, 1, -1, 1, -1, -1, 1, -1, 1, -1, 1, -1, 1],
[1, 1, -1, -1, 1, 1, -1, -1, -1, -1, 1, 1, -1, -1, 1, 1],
[1, -1, -1, 1, 1, -1, -1, 1, -1, 1, 1, -1, -1, 1, 1, -1],
[1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1],
[1, -1, 1, -1, -1, 1, -1, 1, -1, 1, -1, 1, 1, -1, 1, -1],
[1, 1, -1, -1, -1, -1, 1, 1, -1, -1, 1, 1, 1, 1, -1, -1],
[1, -1, -1, 1, -1, 1, 1, -1, -1, 1, 1, -1, 1, -1, -1, 1],
],
dtype=torch.float32,
)
* hadamard_scale
)
@functools.lru_cache(maxsize=None)
def get_rht_matrix(with_random_sign_mask: bool) -> torch.Tensor:
"""Construct matrix used in random Hadamard transform."""
hadamard_dimension = 16
if with_random_sign_mask:
signs = get_wgrad_sign_vector()
else:
signs = get_no_random_sign_vector()
sign_matrix = signs * torch.eye(hadamard_dimension, dtype=torch.float32)
rht_matrix = sign_matrix @ get_hadamard_matrix(hadamard_dimension)
return rht_matrix.to(dtype=torch.bfloat16).cuda()
@functools.lru_cache(maxsize=None)
def get_random_sign_mask_for_rht(with_random_sign_mask: bool) -> int:
"""Sign mask for random Hadamard transform."""
if with_random_sign_mask:
return get_sign_from_vector(get_wgrad_sign_vector())
return 0
class NVFP4Quantizer(Quantizer):
"""Builder class for NVFP4 tensors with NV block scaling"""
dtype: TE_DType
"""Random Hadamard Transform"""
with_rht: bool
with_post_rht_amax: bool
"""amax reduction options"""
with_amax_reduction: bool
amax_reduction_group: Optional[dist_group_type]
"""2D block scaling, only applicable for weights."""
with_2d_quantization: bool
"""Stochastic rounding, only applicable for gradients."""
stochastic_rounding: bool
"""RHT matrix random sign mask"""
rht_matrix_random_sign_mask_t: int
rht_matrix: torch.Tensor
def __init__(
self,
fp4_dtype: TE_DType = tex.DType.kFloat4E2M1,
rowwise: bool = True,
columnwise: bool = True,
with_amax_reduction: bool = False,
amax_reduction_group: Optional[dist_group_type] = None,
with_rht: bool = False,
with_post_rht_amax: bool = False,
with_2d_quantization: bool = False,
stochastic_rounding: bool = False,
with_random_sign_mask: bool = True,
) -> None:
super().__init__(rowwise=rowwise, columnwise=columnwise)
self.dtype = fp4_dtype
self.with_rht = with_rht
self.with_post_rht_amax = with_post_rht_amax
self.with_amax_reduction = with_amax_reduction
self.amax_reduction_group = amax_reduction_group
self.with_2d_quantization = with_2d_quantization
self.stochastic_rounding = stochastic_rounding
self.rht_matrix_random_sign_mask_t = get_random_sign_mask_for_rht(with_random_sign_mask)
self.rht_matrix = get_rht_matrix(with_random_sign_mask)
def update_quantized(
self,
src: torch.Tensor,
dst: QuantizedTensor,
*,
noop_flag: Optional[torch.Tensor] = None,
) -> QuantizedTensor:
assert isinstance(dst, NVFP4Tensor), f"Cannot store quantized NVFP4 in {type(dst)} type."
# Make sure input is in expected format
if not devices_match(src.device, dst.device):
src = src.to(device=dst.device)
if not src.is_contiguous():
src = src.contiguous()
# Launch cast kernel
tex.quantize(src, self, dst, noop_flag)
return dst
def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor:
"""Quantize tensor implementation"""
return tex.quantize(tensor, self)
def is_quantizable(self, inp: torch.Tensor) -> bool:
"""Returns whether or not given inp can be quantized"""
if inp.ndim < 2:
return False
if inp.shape[-1] % NVFP4_BLOCK_SCALING_SIZE != 0:
return False
if math.prod(inp.shape[:-1]) % NVFP4_BLOCK_SCALING_SIZE != 0:
return False
return True
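    # Minimal sketch of the shape check above (hypothetical shapes, for illustration
    # only): a (32, 1024) input is quantizable because 1024 and 32 are both multiples
    # of NVFP4_BLOCK_SCALING_SIZE (16), while a (3, 1024) input is rejected because
    # the product of its leading dims (3) is not.
    #
    #   q = NVFP4Quantizer()
    #   assert q.is_quantizable(torch.empty(32, 1024, device="cuda"))
    #   assert not q.is_quantizable(torch.empty(3, 1024, device="cuda"))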
def get_scale_shape(self, shape: Iterable[int], columnwise: bool) -> Tuple[int, int]:
"""Calculate the shape of the scaling tensor for NVFP4 1D blockwise quantization.
This method determines the shape of the scaling tensor needed for blockwise quantization,
taking into account the input tensor shape and whether columnwise scaling is used.
Parameters
----------
shape : Iterable[int]
Shape of the input tensor to be quantized
columnwise : bool
Whether to use columnwise scaling (True) or rowwise scaling (False)
Returns
-------
Tuple[int, int]
Shape of the scaling tensor as (outer_dim, inner_dim)
For NVFP4 1D blockwise quantization, blocksize is 16
- If columnwise: (round_to_multiple(K, 128), round_to_multiple(roundup(M / 16), 4))
- If rowwise: (round_to_multiple(M, 128), round_to_multiple(roundup(K / 16), 4))
Swizzle kernel will be performed before GEMM to suit the need of CuBLAS.
CuBLAS doc: https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
"""
M, K = 1, 1
M = math.prod(shape[:-1])
K = shape[-1]
if columnwise:
outer = round_up_to_nearest_multiple(K, 128)
inner = round_up_to_nearest_multiple(math.ceil(M / NVFP4_BLOCK_SCALING_SIZE), 4)
return (outer, inner)
# rowwise
outer = round_up_to_nearest_multiple(M, 128)
inner = round_up_to_nearest_multiple(math.ceil(K / NVFP4_BLOCK_SCALING_SIZE), 4)
return (outer, inner)
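    # Worked example for the formulas above (assumed input shape, shown only for
    # illustration): with block size 16 and an input of shape (256, 1024),
    #   rowwise:    (round_up(256, 128), round_up(1024 // 16, 4)) -> (256, 64)
    #   columnwise: (round_up(1024, 128), round_up(256 // 16, 4)) -> (1024, 16)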
@staticmethod
def get_columnwise_shape(shape: Iterable[int]) -> Tuple[int, ...]:
"""Calculate the shape of a tensor after columnwise quantization.
For NVFP4 columnwise quantization, it's performing 16x1 quantization block scaling.
Parameters
----------
shape : Iterable[int]
Original shape of the tensor
Returns
-------
Tuple[int, ...]
New shape with dimensions rearranged for columnwise layout.
For a shape (d1, d2, ..., dn), returns (dn, d1, d2, ..., dn-1).
Returns empty tuple for empty input shape.
"""
if len(shape) == 0:
return tuple()
# and then after AG, a reorganize kernel will be called to restore the shape
colwise_shape = [shape[-1]]
for i in range(len(shape) - 1):
colwise_shape.append(shape[i])
return tuple(colwise_shape)
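    # Example of the dimension rotation above (hypothetical shape): a tensor of shape
    # (4, 8, 16) is stored columnwise as (16, 4, 8), i.e. the last dim moves to the front.
    #
    #   assert NVFP4Quantizer.get_columnwise_shape((4, 8, 16)) == (16, 4, 8)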
@staticmethod
def convert_shape_for_fp4(shape: Iterable[int]) -> Tuple[int, ...]:
"""Convert shape for FP4 data by dividing the last dimension by 2"""
shape = list(shape)
shape[-1] = shape[-1] // 2
return tuple(shape)
def make_empty(
self,
shape: Iterable[int],
*,
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
requires_grad: bool = False,
) -> NVFP4Tensor:
# Canonicalize tensor attributes
if device is None:
device = torch.device("cuda")
assert shape[-1] % NVFP4_BLOCK_SCALING_SIZE == 0, (
f"Incorrect shape {shape} for NVFP4. Tensor dims must be divisible by"
f" {NVFP4_BLOCK_SCALING_SIZE}"
)
flat_first_dim = math.prod(shape[:-1])
assert flat_first_dim % NVFP4_BLOCK_SCALING_SIZE == 0, (
f"Incorrect shape {shape} for NVFP4. Tensor dims must be divisible by"
f" {NVFP4_BLOCK_SCALING_SIZE}"
)
# Allocate FP4 data
data = None
scale_inv = None
amax_rowwise = None
if self.rowwise_usage:
data = torch.empty(self.convert_shape_for_fp4(shape), dtype=torch.uint8, device=device)
scale_shape = self.get_scale_shape(shape, columnwise=False)
scale_inv = torch.empty(scale_shape, dtype=torch.uint8, device=device)
# Allocate per tensor scale inverse. FP32 format.
amax_rowwise = torch.zeros(1, dtype=torch.float32, device=device)
        # Allocate FP4 data for the transpose if needed
columnwise_data = None
columnwise_scale_inv = None
amax_columnwise = None
if self.columnwise_usage:
            # Enforce a 2D shape to avoid cases like [S, B, H] where B can be 1;
            # the transposed shape would be [H, S, B], and dividing its last dim by 2 would give zero
shape_2d = tuple([flat_first_dim, shape[-1]])
columnwise_data = torch.empty(
self.convert_shape_for_fp4(self.get_columnwise_shape(shape_2d)),
dtype=torch.uint8,
device=device,
)
columnwise_scale_shape = self.get_scale_shape(shape, columnwise=True)
columnwise_scale_inv = torch.empty(
columnwise_scale_shape, dtype=torch.uint8, device=device
)
amax_columnwise = torch.zeros(1, dtype=torch.float32, device=device)
        # Construct NVFP4 tensor
return NVFP4Tensor(
shape=shape,
dtype=dtype,
rowwise_data=data,
rowwise_scale_inv=scale_inv,
columnwise_data=columnwise_data,
columnwise_scale_inv=columnwise_scale_inv,
amax_rowwise=amax_rowwise,
amax_columnwise=amax_columnwise,
fp4_dtype=self.dtype,
quantizer=self,
requires_grad=requires_grad,
)
def calibrate(self, tensor: torch.Tensor) -> None:
pass # Calibration is no-op
def _canonicalized_amax_reduction_group(self) -> dist_group_type:
"""Get process group for amax reduction"""
return canonicalize_process_group(self.amax_reduction_group)
def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
return NVFP4BlockScaling
class NVFP4Tensor(NVFP4TensorStorage, QuantizedTensor):
"""Quantized tensor class with FP4 data
The tensor presents as having a standard, higher-precision dtype,
but the data itself is (scaled) FP4. For most tensor operations,
the data will be cast to the nominal dtype before performing the
operation.
Parameters
----------
rowwise_data: torch.Tensor
Raw FP4 data in a uint8 tensor (rowwise layout).
rowwise_scale_inv: torch.Tensor
Reciprocal of the scaling factor applied when
casting to FP4, i.e. the scaling factor that must
be applied when casting from FP4 to higher
precision (rowwise).
columnwise_data: torch.Tensor, optional
Raw FP4 data in a uint8 tensor (columnwise layout).
columnwise_scale_inv: torch.Tensor, optional
Reciprocal of the scaling factor for columnwise FP4 data.
amax_rowwise: torch.Tensor, optional
Rowwise amax tracking tensor.
amax_columnwise: torch.Tensor, optional
Columnwise amax tracking tensor.
fp4_dtype: TE_DType
The FP4 data type used for quantization.
quantizer: Quantizer
The quantizer instance used for this tensor.
dtype: torch.dtype, default = torch.float32
Nominal tensor datatype, used in dequantize.
"""
# NOTE: We reorder the *args so that we can instantiate a NVFP4TensorStorage with positional args,
# which significantly reduces the Pybind11 overhead when calling the constructor from C++.
def __new__(
cls,
*args,
rowwise_data: Optional[torch.Tensor],
rowwise_scale_inv: Optional[torch.Tensor],
columnwise_data: Optional[torch.Tensor],
columnwise_scale_inv: Optional[torch.Tensor],
amax_rowwise: Optional[torch.Tensor],
amax_columnwise: Optional[torch.Tensor],
fp4_dtype: TE_DType,
quantizer: Quantizer,
**kwargs,
):
instance = super().__new__(
cls,
rowwise_data,
rowwise_scale_inv,
columnwise_data,
columnwise_scale_inv,
amax_rowwise,
amax_columnwise,
fp4_dtype,
quantizer,
*args,
**kwargs,
)
return instance
def __repr__(self, *, tensor_contents=None):
return f"NVFP4Tensor, data={self.dequantize(dtype=self.dtype)})"
def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
"""
Construct plain PyTorch tensor from NVFP4Tensor
By default the resulting tensor's dtype is the
NVFP4Tensor's nominal dtype.
"""
# Convert PyTorch dtype to TE dtype
if dtype is None:
dtype = self.dtype
if torch.is_grad_enabled():
return _FromNVFP4Func.apply(self, dtype)
return _FromNVFP4Func.forward(None, self, dtype)
def _get_quantizer(self) -> Quantizer:
"""Get builder for quantized tensor
Quantizer can be used for in-place operations.
"""
if self._quantizer is not None:
return self._quantizer
return NVFP4Quantizer()
def quantize_(
self,
tensor: torch.Tensor,
*,
noop_flag: Optional[torch.Tensor] = None,
) -> NVFP4Tensor:
"""Update FP8 data
Parameters
----------
tensor: torch.Tensor
Tensor to copy from
noop_flag: torch.Tensor, optional
float32 flag indicating whether to avoid performing update
"""
if isinstance(tensor, QuantizedTensor):
return self.quantize_(tensor.dequantize())
self._get_quantizer().update_quantized(tensor, self, noop_flag=noop_flag)
return self
def detach(self) -> NVFP4Tensor:
# pylint: disable=missing-function-docstring
# TODO(ksivamani): Fix the detach bug
return NVFP4Tensor.make_like(self)
def clone(self) -> NVFP4Tensor:
# pylint: disable=missing-function-docstring
assert self._rowwise_data is not None
rowwise_data = self._rowwise_data.detach().clone()
columnwise_data = None
if self._columnwise_data is not None:
columnwise_data = self._columnwise_data.detach().clone()
return _IdentityFunc.apply(
self,
{
"rowwise_data": rowwise_data,
"columnwise_data": columnwise_data,
},
)
def view(self, *shape: Tuple[int]) -> NVFP4Tensor:
# pylint: disable=missing-function-docstring
return _ViewFunc.apply(self, shape)
def reshape(self, *shape: Tuple[int]) -> NVFP4Tensor:
# pylint: disable=missing-function-docstring
return _ReshapeFunc.apply(self, shape)
def contiguous(
self,
memory_format: torch.memory_format = torch.contiguous_format,
) -> NVFP4Tensor:
"""Returns tensor with data in provided memory format
Returns `self` if data is already in correct memory format.
"""
if self._rowwise_data is not None and self._rowwise_data.is_contiguous(
memory_format=memory_format
):
return self
if self._columnwise_data is not None and self._columnwise_data.is_contiguous(
memory_format=memory_format
):
return self
raise ValueError("NVFP4Tensor does not support different memory formats!")
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None):
# View op
if func == aten.view.default:
if len(args) != 2:
raise RuntimeError("Unexpected args for view op (expected 2 args, got {len(args)})")
tensor = args[0]
shape = args[1]
if shape == list(tensor.size()):
return tensor.detach()
return tensor.view(shape)
# NVFP4 dequantize not supported. Add manual support for needed funcs.
if func in (aten.empty_like.default, aten.zero_.default):
tensor = args[0]
data_init_func = torch.zeros_like if func == aten.zero_.default else torch.empty_like
scale_inv_init_func = (
torch.ones_like if func == aten.zero_.default else torch.empty_like
)
if tensor._rowwise_data is not None:
rowwise_data = data_init_func(tensor._rowwise_data)
rowwise_scale_inv = scale_inv_init_func(tensor._rowwise_scale_inv)
amax_rowwise = torch.zeros_like(tensor._amax_rowwise)
else:
rowwise_data, rowwise_scale_inv, amax_rowwise = None, None, None
if tensor._columnwise_data is not None:
columnwise_data = data_init_func(tensor._columnwise_data)
columnwise_scale_inv = scale_inv_init_func(tensor._columnwise_scale_inv)
amax_columnwise = torch.zeros_like(tensor._amax_columnwise)
else:
columnwise_data, columnwise_scale_inv, amax_columnwise = (
None,
None,
None,
)
return NVFP4Tensor(
shape=tensor.shape,
dtype=tensor.dtype,
fp4_dtype=tensor._fp4_dtype,
rowwise_data=rowwise_data,
rowwise_scale_inv=rowwise_scale_inv,
columnwise_data=columnwise_data,
columnwise_scale_inv=columnwise_scale_inv,
amax_rowwise=amax_rowwise,
amax_columnwise=amax_columnwise,
quantizer=tensor._quantizer,
requires_grad=tensor.requires_grad,
)
# Default case
return super().__torch_dispatch__(func, types, args, kwargs)
@classmethod
def _make_in_reduce_ex(
cls,
shape: torch.Size,
rowwise_data: torch.Tensor,
rowwise_scale_inv: torch.Tensor,
columnwise_data: torch.Tensor,
columnwise_scale_inv: torch.Tensor,
amax_rowwise: torch.Tensor,
amax_columnwise: torch.Tensor,
fp4_dtype: TE_DType,
dtype: torch.dtype,
quantizer: Quantizer,
) -> NVFP4Tensor:
"""Build NVFP4Tensor, for use in __reduce__
__reduce_ex__ assumes object constructor has positional
arguments.
"""
return NVFP4Tensor(
shape=shape,
dtype=dtype,
fp4_dtype=fp4_dtype,
rowwise_data=rowwise_data,
rowwise_scale_inv=rowwise_scale_inv,
columnwise_data=columnwise_data,
columnwise_scale_inv=columnwise_scale_inv,
amax_rowwise=amax_rowwise,
amax_columnwise=amax_columnwise,
quantizer=quantizer,
requires_grad=False,
)
def __reduce_ex__(self, protocol: int) -> tuple:
"""Custom pickling"""
return (
NVFP4Tensor._make_in_reduce_ex,
(
self.shape,
self._rowwise_data,
self._rowwise_scale_inv,
self._columnwise_data,
self._columnwise_scale_inv,
self._amax_rowwise,
self._amax_columnwise,
self._fp4_dtype,
self.dtype,
self._quantizer,
),
)
def _get_data(self) -> NVFP4Tensor:
"""Get tensor data property"""
return super().data
@torch.no_grad()
def _set_data(self, tensor: torch.Tensor) -> None:
"""Set tensor data property
        Just takes NVFP4 data if setting from an NVFP4Tensor. Otherwise
        casts to NVFP4.
"""
# Tensor device
new_device = tensor.device if tensor.is_cuda else self.device
if not devices_match(new_device, tensor.device):
tensor = tensor.to(device=new_device)
        # Just copy NVFP4 data if other tensor is NVFP4Tensor
if isinstance(tensor, NVFP4Tensor):
if ( # pylint: disable=too-many-boolean-expressions
self.size() != tensor.size()
or self.stride() != tensor.stride()
or self.storage_offset() != tensor.storage_offset()
or self.dtype != tensor.dtype
or self.layout != tensor.layout
or not devices_match(self.device, new_device)
):
dummy_tensor = torch.Tensor._make_wrapper_subclass(
NVFP4Tensor,
tensor.size(),
strides=tensor.stride(),
storage_offset=tensor.storage_offset(),
dtype=tensor.dtype,
layout=tensor.layout,
requires_grad=tensor.requires_grad,
device=new_device,
)
# pylint: disable=unnecessary-dunder-call
super(NVFP4Tensor, type(self)).data.__set__(self, dummy_tensor)
self._rowwise_data = tensor._rowwise_data
self._columnwise_data = tensor._columnwise_data
self._quantizer = tensor._quantizer
self._rowwise_scale_inv = tensor._rowwise_scale_inv
self._columnwise_scale_inv = tensor._columnwise_scale_inv
self._amax_rowwise = tensor._amax_rowwise
self._amax_columnwise = tensor._amax_columnwise
return
        # Quantize to NVFP4
assert self._quantizer is not None, "Can't quantize without a quantizer"
self._quantizer.update_quantized(tensor, self)
if self.requires_grad != tensor.requires_grad:
self.requires_grad_(requires_grad=tensor.requires_grad)
    # Cast to NVFP4 when setting NVFP4Tensor.data
data = property(_get_data, _set_data)
class _ViewFunc(torch.autograd.Function):
"""View function
View the NVFP4Tensor using the provided shape.
"""
@staticmethod
def forward(
ctx,
tensor: NVFP4Tensor,
shape: Optional[list[int]] = None,
) -> NVFP4Tensor:
# pylint: disable=missing-function-docstring
# Return input tensor if shape is not provided
cur_shape = tensor.shape
if ctx is not None:
ctx.shape = cur_shape
if shape is None:
return tensor
# Canonicalize shape
if not isinstance(shape, Iterable):
shape = [shape]
elif len(shape) == 1 and isinstance(shape[0], Iterable):
shape = shape[0]
if -1 in shape:
shape = list(shape)
d_inferred = -math.prod(cur_shape) // math.prod(shape)
for i, d in enumerate(shape):
if d == -1:
shape[i] = d_inferred
break
if shape[-1] != cur_shape[-1]:
raise RuntimeError(
"NVFP4Tensor does not support reshaping inner dimension "
f"(attempted to reshape dims={tuple(tensor.shape)} to {tuple(shape)})"
)
# Reshape data
new_rowwise_data = None
new_columnwise_data = None
if tensor._rowwise_data is not None:
if shape[-1] % 2 != 0:
raise ValueError(
"Cannot represent row-wise data for NVFP4 tensor "
f"with shape={shape} as byte array."
)
byte_shape = list(shape[:-1]) + [shape[-1] // 2]
new_rowwise_data = tensor._rowwise_data.view(byte_shape)
if tensor._columnwise_data is not None:
columnwise_shape = (shape[-1], math.prod(shape[:-1]))
if columnwise_shape[-1] % 2 != 0:
raise ValueError(
"Cannot represent column-wise data for NVFP4 tensor "
f"with shape={shape} as byte array."
)
byte_shape = (columnwise_shape[0], columnwise_shape[1] // 2)
new_columnwise_data = tensor._columnwise_data.view(byte_shape)
# Construct tensor
return NVFP4Tensor(
shape,
tensor.dtype,
rowwise_data=new_rowwise_data,
rowwise_scale_inv=tensor._rowwise_scale_inv,
columnwise_data=new_columnwise_data,
columnwise_scale_inv=tensor._columnwise_scale_inv,
amax_rowwise=tensor._amax_rowwise,
amax_columnwise=tensor._amax_columnwise,
quantizer=tensor._quantizer,
fp4_dtype=tensor._fp4_dtype,
requires_grad=tensor.requires_grad,
)
@staticmethod
def backward(
ctx,
grad: torch.Tensor,
) -> Tuple[Optional[torch.Tensor], ...]:
# pylint: disable=missing-function-docstring
if isinstance(grad, NVFP4Tensor):
new_rowwise_data = None
new_columnwise_data = None
if grad._rowwise_data is not None:
if ctx.shape[-1] % 2 != 0:
raise ValueError(
"Cannot represent row-wise data for NVFP4 tensor "
f"with shape={ctx.shape} as byte array."
)
byte_shape = list(ctx.shape[:-1]) + [ctx.shape[-1] // 2]
new_rowwise_data = grad._rowwise_data.view(byte_shape)
if grad._columnwise_data is not None:
columnwise_shape = (ctx.shape[-1], math.prod(ctx.shape[:-1]))
if columnwise_shape[-1] % 2 != 0:
raise ValueError(
"Cannot represent column-wise data for NVFP4 tensor "
f"with shape={ctx.shape} as byte array."
)
byte_shape = (columnwise_shape[0], columnwise_shape[1] // 2)
new_columnwise_data = grad._columnwise_data.view(byte_shape)
dgrad = NVFP4Tensor(
ctx.shape,
grad.dtype,
rowwise_data=new_rowwise_data,
rowwise_scale_inv=grad._rowwise_scale_inv,
columnwise_data=new_columnwise_data,
columnwise_scale_inv=grad._columnwise_scale_inv,
amax_rowwise=grad._amax_rowwise,
amax_columnwise=grad._amax_columnwise,
quantizer=grad._quantizer,
fp4_dtype=grad._fp4_dtype,
requires_grad=grad.requires_grad,
)
return dgrad, None
return grad.view(ctx.shape), None
class _ReshapeFunc(torch.autograd.Function):
"""Reshape function
Reshape the NVFP4Tensor using the provided shape.
"""
@staticmethod
def forward(
ctx,
tensor: NVFP4Tensor,
shape: Optional[list[int]] = None,
) -> NVFP4Tensor:
# pylint: disable=missing-function-docstring
# Return input tensor if shape is not provided
cur_shape = tensor.shape
if ctx is not None:
ctx.shape = cur_shape
if shape is None:
return tensor
# Canonicalize shape
if not isinstance(shape, Iterable):
shape = [shape]
elif len(shape) == 1 and isinstance(shape[0], Iterable):
shape = shape[0]
if -1 in shape:
shape = list(shape)
d_inferred = -math.prod(cur_shape) // math.prod(shape)
for i, d in enumerate(shape):
if d == -1:
shape[i] = d_inferred
break
if shape[-1] != cur_shape[-1]:
raise RuntimeError(
"NVFP4Tensor does not support reshaping inner dimension "
f"(attempted to reshape dims={tuple(tensor.shape)} to {tuple(shape)})"
)
# Reshape data
new_rowwise_data = None
new_columnwise_data = None
if tensor._rowwise_data is not None:
if shape[-1] % 2 != 0:
raise ValueError(
"Cannot represent row-wise data for NVFP4 tensor "
f"with shape={shape} as byte array."
)
byte_shape = list(shape[:-1]) + [shape[-1] // 2]
new_rowwise_data = tensor._rowwise_data.reshape(byte_shape)
if tensor._columnwise_data is not None:
columnwise_shape = (shape[-1], math.prod(shape[:-1]))
if columnwise_shape[-1] % 2 != 0:
raise ValueError(
"Cannot represent column-wise data for NVFP4 tensor "
f"with shape={shape} as byte array."
)
byte_shape = (columnwise_shape[0], columnwise_shape[1] // 2)
new_columnwise_data = tensor._columnwise_data.reshape(byte_shape)
# Construct tensor
return NVFP4Tensor(
shape,
tensor.dtype,
rowwise_data=new_rowwise_data,
rowwise_scale_inv=tensor._rowwise_scale_inv,
columnwise_data=new_columnwise_data,
columnwise_scale_inv=tensor._columnwise_scale_inv,
amax_rowwise=tensor._amax_rowwise,
amax_columnwise=tensor._amax_columnwise,
quantizer=tensor._quantizer,
fp4_dtype=tensor._fp4_dtype,
requires_grad=tensor.requires_grad,
)
@staticmethod
def backward(
ctx,
grad: torch.Tensor,
) -> Tuple[Optional[torch.Tensor], ...]:
# pylint: disable=missing-function-docstring
if isinstance(grad, NVFP4Tensor):
new_rowwise_data = None
new_columnwise_data = None
if grad._rowwise_data is not None:
if ctx.shape[-1] % 2 != 0:
raise ValueError(
"Cannot represent row-wise data for NVFP4 tensor "
f"with shape={ctx.shape} as byte array."
)
byte_shape = list(ctx.shape[:-1]) + [ctx.shape[-1] // 2]
new_rowwise_data = grad._rowwise_data.reshape(byte_shape)
if grad._columnwise_data is not None:
columnwise_shape = (ctx.shape[-1], math.prod(ctx.shape[:-1]))
if columnwise_shape[-1] % 2 != 0:
raise ValueError(
"Cannot represent column-wise data for NVFP4 tensor "
f"with shape={ctx.shape} as byte array."
)
byte_shape = (columnwise_shape[0], columnwise_shape[1] // 2)
new_columnwise_data = grad._columnwise_data.reshape(byte_shape)
dgrad = NVFP4Tensor(
ctx.shape,
grad.dtype,
rowwise_data=new_rowwise_data,
rowwise_scale_inv=grad._rowwise_scale_inv,
columnwise_data=new_columnwise_data,
columnwise_scale_inv=grad._columnwise_scale_inv,
amax_rowwise=grad._amax_rowwise,
amax_columnwise=grad._amax_columnwise,
quantizer=grad._quantizer,
fp4_dtype=grad._fp4_dtype,
requires_grad=grad.requires_grad,
)
return dgrad, None
return grad.view(ctx.shape), None
......@@ -5,7 +5,7 @@
"""Tensor with quantized data"""
from __future__ import annotations
from typing import Optional, Tuple, Iterable, Any, Dict, Union
from typing import Callable, Optional, Tuple, Iterable, Any, Dict, Union
import abc
import copy
import warnings
......@@ -13,12 +13,11 @@ import warnings
import torch
from torch.utils._pytree import tree_map
import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe
class QuantizedTensorBase:
r"""Base class for all *TensorBase classes.
class QuantizedTensorStorage:
r"""Base class for all *TensorStorage classes.
This class (and its subclasses) are optimization for when
the full QuantizedTensor is not needed (when it is fully
......@@ -26,9 +25,9 @@ class QuantizedTensorBase:
PyTorch's autograd).
When creating a new tensor type X one should create both
XTensorBase class inheriting from QuantizedTensorBase and
XTensor inheriting from XTensorBase and QuantizedTensor.
XTensorBase should contain all data members needed to
XTensorStorage class inheriting from QuantizedTensorStorage and
XTensor inheriting from XTensorStorage and QuantizedTensor.
XTensorStorage should contain all data members needed to
implement the functionality of the tensor, while
XTensor should only implement the functionality needed
    to behave like a regular torch.Tensor (like __torch_dispatch__)."""
......@@ -59,7 +58,7 @@ class QuantizedTensorBase:
f"{self.__class__.__name__} class does not implement update_usage function"
)
def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], QuantizedTensorBase]:
def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], QuantizedTensorStorage]:
"""Prepare the tensor base for saving for backward"""
raise NotImplementedError(
f"{self.__class__.__name__} class does not implement prepare_for_saving function"
......@@ -73,6 +72,30 @@ class QuantizedTensorBase:
f"{self.__class__.__name__} class does not implement restore_from_saved function"
)
def _get_quantizer(self) -> Quantizer:
"""Get builder for quantized tensor
Quantizer can be used for in-place operations.
"""
if self._quantizer is not None:
return self._quantizer
return self._build_default_quantizer()
def _build_default_quantizer(self) -> Quantizer:
"""Build default quantizer for the tensor"""
raise ValueError(
f"{self.__class__.__name__} has no quantizer "
"and no default quantizer is available defined in the subclass."
)
def quantize_(
self, tensor: torch.Tensor, *, noop_flag: Optional[torch.Tensor] = None
) -> QuantizedTensor:
"""Quantize tensor in-place"""
self._get_quantizer().update_quantized(tensor, self, noop_flag=noop_flag)
return self
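    # Minimal sketch of how a subclass is expected to hook into quantize_() above
    # (hypothetical class, not part of this file): it only needs to provide
    # _build_default_quantizer(); _get_quantizer() and quantize_() then work unchanged.
    #
    #   class MyTensorStorage(QuantizedTensorStorage):
    #       def _build_default_quantizer(self) -> Quantizer:
    #           return MyQuantizer(dtype=self._my_dtype)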
def update_quantizer(self, quantizer: Quantizer):
"""Update quantizer for the tensor"""
if self._quantizer is None:
......@@ -83,13 +106,13 @@ class QuantizedTensorBase:
def prepare_for_saving(
*tensors: Union[torch.Tensor, QuantizedTensorBase],
*tensors: Union[torch.Tensor, QuantizedTensorStorage],
) -> Tuple[
list[Optional[Union[torch.Tensor, torch.nn.Parameter]]], list[Optional[QuantizedTensorBase]]
list[Optional[Union[torch.Tensor, torch.nn.Parameter]]], list[Optional[QuantizedTensorStorage]]
]:
"""Prepare tensors for saving. Needed because save_for_backward accepts only
torch.Tensor/torch.nn.Parameter types, while we want to be able to save
the internal TensorBase types too."""
the internal *TensorStorage types too."""
tensor_list, tensor_objects_list = [], []
for tensor in tensors:
......@@ -104,12 +127,12 @@ def prepare_for_saving(
def restore_from_saved(
tensors: list[Optional[Union[torch.Tensor, QuantizedTensorBase]]],
tensors: list[Optional[Union[torch.Tensor, QuantizedTensorStorage]]],
saved_tensors: list[Optional[Union[torch.Tensor, torch.nn.Parameter]]],
return_saved_tensors: bool = False,
) -> (
list[Optional[torch.Tensor | QuantizedTensorBase]]
| tuple[list[Optional[torch.Tensor | QuantizedTensorBase]], list[Optional[torch.Tensor]]]
list[Optional[torch.Tensor | QuantizedTensorStorage]]
| tuple[list[Optional[torch.Tensor | QuantizedTensorStorage]], list[Optional[torch.Tensor]]]
):
"""Recombine the tensor data and metadata during backward pass."""
tensor_objects = []
......@@ -178,7 +201,6 @@ class Quantizer(abc.ABC):
")"
)
@abc.abstractmethod
def update_quantized(
self,
src: torch.Tensor,
......@@ -187,6 +209,9 @@ class Quantizer(abc.ABC):
noop_flag: Optional[torch.Tensor] = None,
) -> QuantizedTensor:
"""Quantize tensor in-place"""
raise NotImplementedError(
f"{self.__class__.__name__} class does not implement update_quantized"
)
def quantize(
self,
......@@ -199,8 +224,14 @@ class Quantizer(abc.ABC):
if out is not None:
return self.update_quantized(tensor, out)
if (not self.internal) and torch.is_grad_enabled():
return _QuantizeFunc.apply(tensor, self)
return _QuantizeFunc.forward(None, tensor, self)
return _QuantizeFunc.apply(tensor, self.quantize_impl)
return _QuantizeFunc.forward(None, tensor, self.quantize_impl)
def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor:
"""Quantize tensor implementation"""
raise NotImplementedError(
f"{self.__class__.__name__} class does not implement quantize_impl function"
)
def multi_quantize(self, list_of_tensors):
"""Quantize multiple tensors"""
......@@ -213,7 +244,6 @@ class Quantizer(abc.ABC):
"""Quantize tensor"""
return self.quantize(tensor)
@abc.abstractmethod
def make_empty(
self,
shape: Iterable[int],
......@@ -222,8 +252,11 @@ class Quantizer(abc.ABC):
device: Optional[torch.device] = None,
) -> QuantizedTensor:
"""Construct quantized tensor with uninitialized data"""
raise NotImplementedError(
f"{self.__class__.__name__} class does not implement make_empty function, "
"required for construction of unintialized quantized tensor"
)
@abc.abstractmethod
def calibrate(self, tensor: torch.Tensor) -> None:
"""Calibrate quantizer state
......@@ -252,34 +285,47 @@ class Quantizer(abc.ABC):
def onnx_quantize(self, tensor: torch.Tensor) -> QuantizedTensor:
"""Symbolic function for ONNX export"""
raise NotImplementedError(
f"{self.__class__.__name__} class does not implement onnx_quantize"
)
def onnx_dequantize(self, tensor) -> torch.Tensor:
"""Symbolic function for ONNX export"""
raise NotImplementedError(
f"{self.__class__.__name__} class does not implement onnx_dequantize"
)
@abc.abstractmethod
def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
"""Returns recipe class that is compatible with this quantizer"""
raise NotImplementedError(
f"{self.__class__.__name__} class does not implement _get_compatible_recipe"
)
def supports_only_rowwise_all_gather(self) -> bool:
"""Returns True if the quantizer supports only rowwise all-gather"""
return False
def is_quantizable(self, inp: torch.Tensor) -> bool: # pylint: disable=unused-argument
"""Returns whether or not given tensor can be quantized"""
return True
class _QuantizeFunc(torch.autograd.Function):
"""Cast to FP8 from other dtype"""
"""Quantize tensor"""
@staticmethod
def forward(
_ctx: Optional[torch.autograd.function.FunctionCtx], # unused
tensor: torch.Tensor,
quantizer: Quantizer,
quantize_impl: Callable,
) -> QuantizedTensor:
# pylint: disable=missing-function-docstring
return tex.quantize(tensor, quantizer)
return quantize_impl(tensor)
@staticmethod
def backward(
_ctx: torch.autograd.function.FunctionCtx, grad: torch.Tensor # unused
_ctx: torch.autograd.function.FunctionCtx, # unused
grad: torch.Tensor,
) -> Tuple[Optional[torch.Tensor], ...]:
# pylint: disable=missing-function-docstring
# Assume that we want gradients in full precision
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Storage for quantized tensors."""
from .float8_tensor_storage import Float8TensorStorage # noqa: F401
from .mxfp8_tensor_storage import MXFP8TensorStorage # noqa: F401
from .float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage # noqa: F401
from .nvfp4_tensor_storage import NVFP4TensorStorage # noqa: F401
......@@ -14,7 +14,7 @@ from transformer_engine_torch import DType as TE_DType
from transformer_engine_torch import Float8BlockScaleTensorFormat
from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
from ..quantized_tensor import QuantizedTensorBase
from ..quantized_tensor import QuantizedTensorStorage
from ...constants import TE_DType_To_Torch
......@@ -23,7 +23,7 @@ from ..quantized_tensor import Quantizer
from ...utils import _empty_tensor
class Float8BlockwiseQTensorBase(QuantizedTensorBase):
class Float8BlockwiseQTensorStorage(QuantizedTensorStorage):
"""Mixin class that holds data attributes of Float8BlockwiseQTensor.
Float8BlockwiseQTensor inherits from the PyTorch tensor class and this
......@@ -54,7 +54,7 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
*args,
**kwargs,
):
if cls is Float8BlockwiseQTensorBase:
if cls is Float8BlockwiseQTensorStorage:
instance = object.__new__(cls)
else:
instance = super().__new__(cls, *args, **kwargs)
......@@ -99,7 +99,7 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
def prepare_for_saving(
self,
) -> Tuple[list[Optional[torch.Tensor]], Float8BlockwiseQTensorBase]:
) -> Tuple[list[Optional[torch.Tensor]], Float8BlockwiseQTensorStorage]:
"""
Prepare the tensor base for saving for backward
"""
......@@ -367,7 +367,7 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
data = self.dequantize()
descriptor = "columnwise"
return (
"Float8BlockwiseQTensorBase("
"Float8BlockwiseQTensorStorage("
f"fp8_dtype={self._fp8_dtype}, "
f"{descriptor}_scaled_data={data}"
)
......
......@@ -12,7 +12,7 @@ import torch
import transformer_engine_torch as tex
from transformer_engine_torch import DType as TE_DType
from ..quantized_tensor import QuantizedTensorBase
from ..quantized_tensor import QuantizedTensorStorage
from ...constants import TE_DType as torch_to_transformer_engine_dtype
......@@ -27,7 +27,7 @@ class _FromFloat8Func(torch.autograd.Function):
@staticmethod
def forward(
_ctx: Optional[torch.autograd.function.FunctionCtx], # unused
tensor: Float8TensorBase,
tensor: Float8TensorStorage,
dtype: torch.dtype,
) -> torch.Tensor:
# pylint: disable=missing-function-docstring
......@@ -52,7 +52,7 @@ class _FromFloat8Func(torch.autograd.Function):
return grad, None
class Float8TensorBase(QuantizedTensorBase):
class Float8TensorStorage(QuantizedTensorStorage):
"""Mixin class that holds data attributes of Float8Tensor.
Float8Tensor inherits from the PyTorch tensor class and this mixin
......@@ -81,7 +81,7 @@ class Float8TensorBase(QuantizedTensorBase):
quantizer: Optional[Quantizer] = None,
**kwargs,
):
if cls is Float8TensorBase:
if cls is Float8TensorStorage:
instance = object.__new__(cls)
else:
instance = super().__new__(cls, *args, **kwargs)
......@@ -116,7 +116,7 @@ class Float8TensorBase(QuantizedTensorBase):
"quantizer": self._quantizer,
}
def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], QuantizedTensorBase]:
def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], QuantizedTensorStorage]:
"""Prepare the tensor base for saving for backward"""
tensors = [self._data, self._transpose, self._scale_inv]
self._data = None
......@@ -163,7 +163,7 @@ class Float8TensorBase(QuantizedTensorBase):
if out_transpose_shape[0] != shape[-1] or out_transpose_shape[1:] != shape[:-1]:
out_transpose = None
return Float8TensorBase(
return Float8TensorStorage(
data=out_data,
fp8_scale_inv=self._scale_inv,
fp8_dtype=self._fp8_dtype,
......@@ -173,7 +173,7 @@ class Float8TensorBase(QuantizedTensorBase):
def __repr__(self):
return (
"Float8TensorBase("
"Float8TensorStorage("
f"fp8_dtype={self._fp8_dtype}, "
f"scale_inv={self._scale_inv.item()}, "
f"data={self.dequantize()}"
......
......@@ -13,7 +13,7 @@ import torch
import transformer_engine_torch as tex
from transformer_engine_torch import DType as TE_DType
from ..quantized_tensor import QuantizedTensorBase
from ..quantized_tensor import QuantizedTensorStorage
from ...constants import TE_DType as torch_to_transformer_engine_dtype
......@@ -28,7 +28,7 @@ class _FromMXFP8Func(torch.autograd.Function):
@staticmethod
def forward(
_ctx: Optional[torch.autograd.function.FunctionCtx], # unused
tensor: MXFP8TensorBase,
tensor: MXFP8TensorStorage,
dtype: torch.dtype,
) -> torch.Tensor:
# pylint: disable=missing-function-docstring
......@@ -49,7 +49,7 @@ class _FromMXFP8Func(torch.autograd.Function):
return grad, None
class MXFP8TensorBase(QuantizedTensorBase):
class MXFP8TensorStorage(QuantizedTensorStorage):
"""Mixin class that holds data attributes of MXFP8Tensor.
MXFP8Tensor inherits from the PyTorch tensor class and this mixin
......@@ -77,7 +77,7 @@ class MXFP8TensorBase(QuantizedTensorBase):
*args,
**kwargs,
):
if cls is MXFP8TensorBase:
if cls is MXFP8TensorStorage:
instance = object.__new__(cls)
else:
instance = super().__new__(cls, *args, **kwargs)
......@@ -112,7 +112,7 @@ class MXFP8TensorBase(QuantizedTensorBase):
"quantizer": self._quantizer,
}
def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], MXFP8TensorBase]:
def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], MXFP8TensorStorage]:
"""Prepare the tensor base for saving for backward"""
tensors = [
self._rowwise_data,
......@@ -192,7 +192,7 @@ class MXFP8TensorBase(QuantizedTensorBase):
if cur_columnwise_data is not None:
new_columnwise_data = cur_columnwise_data.view(*shape)
return MXFP8TensorBase(
return MXFP8TensorStorage(
rowwise_data=new_rowwise_data,
rowwise_scale_inv=self._rowwise_scale_inv,
columnwise_data=new_columnwise_data,
......@@ -205,7 +205,7 @@ class MXFP8TensorBase(QuantizedTensorBase):
data_rowwise = self.dequantize()
return (
"MXFP8TensorBase("
"MXFP8TensorStorage("
f"fp8_dtype={self._fp8_dtype}, "
f"rowwise_scaled_data={data_rowwise}"
f"rowwise_scale_inv={self._rowwise_scale_inv}, "
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Mixin class holding data specific for NVFP4Tensor"""
from __future__ import annotations
from collections.abc import Iterable
import functools
import math
from typing import Any, Dict, Optional, Tuple, Union
import warnings
import torch
# import transformer_engine_torch as tex
from transformer_engine_torch import DType as TE_DType
from ..quantized_tensor import QuantizedTensorStorage
# from ...constants import TE_DType as torch_to_transformer_engine_dtype
from ..quantized_tensor import Quantizer
from ...utils import _empty_tensor
@functools.lru_cache(maxsize=None)
def _fp4_e2m1_vals(device: torch.device, dtype: torch.dtype) -> torch.Tensor:
"""Values representable in FP4 E2M1 format"""
return torch.tensor(
[0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0],
device=device,
dtype=dtype,
)
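# Illustrative sketch of how the lookup table above is used when unpacking NVFP4
# data (hypothetical byte value): each uint8 stores two FP4 E2M1 codes, low nibble
# first. For a packed byte 0x72, the low nibble 0x2 maps to 1.0 and the high
# nibble 0x7 maps to 6.0.
#
#   packed = torch.tensor([0x72], dtype=torch.uint8)
#   codes = torch.stack((packed & 0x0F, packed >> 4), dim=-1).to(torch.int64)  # [[2, 7]]
#   vals = _fp4_e2m1_vals(packed.device, torch.float32)[codes]  # [[1.0, 6.0]]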
class _FromNVFP4Func(torch.autograd.Function):
"""Cast from NVFP4 to other dtype"""
@staticmethod
def forward(
_ctx: Optional[torch.autograd.function.FunctionCtx], # unused
tensor: NVFP4TensorStorage,
dtype: torch.dtype,
) -> torch.Tensor:
# pylint: disable=missing-function-docstring
# Dequantize row-wise data
if tensor._rowwise_data is not None:
### TODO(tmoon): Debug dequantize kernel and remove unfused impl
# return tex.dequantize(tensor, torch_to_transformer_engine_dtype[dtype])
# Tensor properties
shape = list(tensor._rowwise_data.size())
shape[-1] *= 2
device = tensor._rowwise_data.device
# Convert FP4E2M1 values to FP32
data = tensor._rowwise_data.view(torch.uint8).to(torch.int32)
data = torch.stack((data & 0x0F, data >> 4), dim=-1).reshape(shape)
data = _fp4_e2m1_vals(device, dtype=torch.float32)[data]
data = data.to(torch.float32).contiguous()
# Convert FP8E4M3 block scales to FP32
block_scales = tensor._rowwise_scale_inv
block_scales = block_scales.reshape(-1, block_scales.size(-1))
block_scales = block_scales[: math.prod(shape[:-1]), : shape[-1] // 16]
block_scales = block_scales.view(torch.float8_e4m3fn).to(torch.float32)
# Convert amax to FP32 tensor scale
tensor_scale = tensor._amax_rowwise / (6.0 * 448.0) # Scale by FP4E2M1 and FP8E4M3 max
# Apply scales
block_data = data.view(-1, 16)
block_data *= tensor_scale.view(()) * block_scales.reshape(-1, 1)
return data.to(dtype)
if tensor._columnwise_data is not None:
raise NotImplementedError("Dequantizing column-wise NVFP4 data is not implemented yet!")
raise ValueError("Attempted to dequantize NVFP4 tensor with no data")
@staticmethod
def backward(
_ctx: torch.autograd.function.FunctionCtx, # unused
grad: torch.Tensor,
) -> Tuple[Optional[torch.Tensor], ...]:
# pylint: disable=missing-function-docstring
# Assume that we want gradients in full precision
return grad, None
class NVFP4TensorStorage(QuantizedTensorStorage):
"""Mixin class that holds data attributes of NVFP4Tensor.
NVFP4Tensor inherits from the PyTorch tensor class and this mixin
class. If this class is instantiated directly, it has the same
data, lower CPU overhead, and less functionality. It should only
be instantiated directly for performance-critical internal usage.
"""
_rowwise_data: Optional[torch.Tensor]
_columnwise_data: Optional[torch.Tensor]
_quantizer: Optional[Quantizer]
_rowwise_scale_inv: torch.Tensor
_columnwise_scale_inv: torch.Tensor
_fp4_dtype: TE_DType
_amax_rowwise: torch.Tensor
_amax_columnwise: torch.Tensor
def __new__(
cls,
rowwise_data: Optional[torch.Tensor],
rowwise_scale_inv: torch.Tensor,
columnwise_data: Optional[torch.Tensor],
columnwise_scale_inv: torch.Tensor,
amax_rowwise: torch.Tensor,
amax_columnwise: torch.Tensor,
fp4_dtype: TE_DType,
quantizer: Optional[Quantizer],
*args,
**kwargs,
):
instance = super().__new__(cls, *args, **kwargs)
instance._rowwise_data = rowwise_data
instance._columnwise_data = columnwise_data
instance._fp4_dtype = fp4_dtype
instance._quantizer = quantizer.copy() if quantizer is not None else None
instance._rowwise_scale_inv = rowwise_scale_inv
instance._columnwise_scale_inv = columnwise_scale_inv
instance._amax_rowwise = amax_rowwise
instance._amax_columnwise = amax_columnwise
return instance
def clear(self):
"""Deallocate this tensor's memory. Typically not needed and must be used carefully."""
for t in (
self._rowwise_data,
self._columnwise_data,
self._rowwise_scale_inv,
self._columnwise_scale_inv,
self._amax_rowwise,
self._amax_columnwise,
):
if t is not None:
t.data = _empty_tensor()
def get_metadata(self) -> Dict[str, Any]:
"""Get this tensor's metadata."""
return {
"rowwise_data": self._rowwise_data,
"rowwise_scale_inv": self._rowwise_scale_inv,
"columnwise_data": self._columnwise_data,
"columnwise_scale_inv": self._columnwise_scale_inv,
"amax_rowwise": self._amax_rowwise,
"amax_columnwise": self._amax_columnwise,
"fp4_dtype": self._fp4_dtype,
"quantizer": self._quantizer,
}
def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], NVFP4TensorStorage]:
"""Prepare the tensor base for saving for backward"""
tensors = [
self._rowwise_data,
self._columnwise_data,
self._rowwise_scale_inv,
self._columnwise_scale_inv,
self._amax_rowwise,
self._amax_columnwise,
]
self._rowwise_data = None
self._columnwise_data = None
self._rowwise_scale_inv = None
self._columnwise_scale_inv = None
self._amax_rowwise = None
self._amax_columnwise = None
return tensors, self
def restore_from_saved(
self, tensors: list[Optional[torch.Tensor]]
) -> list[Optional[torch.Tensor]]:
"""Restore the tensor base data from the saved tensors list."""
self._rowwise_data = tensors[0]
self._columnwise_data = tensors[1]
self._rowwise_scale_inv = tensors[2]
self._columnwise_scale_inv = tensors[3]
self._amax_rowwise = tensors[4]
self._amax_columnwise = tensors[5]
return tensors[6:]
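    # Sketch of the intended save/restore round trip (names are illustrative): the
    # storage object is kept for backward while its raw tensors go through
    # save_for_backward, and restore_from_saved() consumes them back from the front
    # of the saved-tensor list, returning whatever is left for other objects.
    #
    #   tensors, storage = nvfp4_storage.prepare_for_saving()
    #   # ... ctx.save_for_backward(*tensors) ...
    #   remaining = storage.restore_from_saved(list(ctx.saved_tensors))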
def get_data_tensors(self):
"""Get this Tensor's data."""
return self._rowwise_data, self._columnwise_data
def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor:
"""Dequantize to a higher precision."""
return _FromNVFP4Func.forward(None, self, dtype)
def size(self, dim: Optional[int] = None) -> Union[torch.Size, int]:
# pylint: disable=missing-function-docstring
# Infer tensor shape
shape = None
if self._rowwise_data is not None:
byte_shape = list(self._rowwise_data.size())
shape = byte_shape[:-1] + [byte_shape[-1] * 2]
elif self._columnwise_data is not None:
warnings.warn("Attempting to get shape of NVFP4 tensor with only column-wise data.")
byte_shape = list(self._columnwise_data.size())
shape = byte_shape[1:-1] + [byte_shape[-1] * 2, byte_shape[0]]
if shape is None:
raise RuntimeError("Attempted to get shape of NVFP4 tensor with no data")
# Return shape or dim
if dim is None:
return torch.Size(shape)
return shape[dim]
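    # Example of the shape inference above (hypothetical data): row-wise FP4 data is
    # packed two values per byte, so a byte tensor of shape (128, 512) corresponds to
    # a logical NVFP4 tensor of shape (128, 1024).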
def view(self, shape: torch.Size):
# pylint: disable=missing-function-docstring
# Return input tensor if view not needed
cur_shape = self.size()
if shape is None or shape == cur_shape:
return self
# Canonicalize shape
if not isinstance(shape, Iterable):
shape = [shape]
elif len(shape) == 1 and isinstance(shape[0], Iterable):
shape = shape[0]
if -1 in shape:
shape = list(shape)
d_inferred = -math.prod(cur_shape) // math.prod(shape)
for i, d in enumerate(shape):
if d == -1:
shape[i] = d_inferred
break
if shape[-1] != cur_shape[-1]:
raise RuntimeError(
"NVFP4Tensor does not support reshaping inner dimension "
f"(attempted to reshape dims={tuple(cur_shape)} to {tuple(shape)})"
)
# Reshape data
new_rowwise_data = None
new_columnwise_data = None
if self._rowwise_data is not None:
if shape[-1] % 2 != 0:
raise ValueError(
"Cannot represent row-wise data for NVFP4 tensor "
f"with shape={shape} as byte array."
)
byte_shape = list(shape[:-1]) + [shape[-1] // 2]
new_rowwise_data = self._rowwise_data.view(byte_shape)
if self._columnwise_data is not None:
columnwise_shape = (shape[-1], math.prod(shape[:-1]))
if columnwise_shape[-1] % 2 != 0:
raise ValueError(
"Cannot represent column-wise data for NVFP4 tensor "
f"with shape={shape} as byte array."
)
byte_shape = (columnwise_shape[0], columnwise_shape[1] // 2)
new_columnwise_data = self._columnwise_data.view(byte_shape)
# Construct tensor
return NVFP4TensorStorage(
rowwise_data=new_rowwise_data,
rowwise_scale_inv=self._rowwise_scale_inv,
columnwise_data=new_columnwise_data,
columnwise_scale_inv=self._columnwise_scale_inv,
amax_rowwise=self._amax_rowwise,
amax_columnwise=self._amax_columnwise,
quantizer=self._quantizer,
fp4_dtype=self._fp4_dtype,
)
def __repr__(self):
data_rowwise = self.dequantize()
return (
"NVFP4TensorStorage("
f"rowwise_scaled_data={data_rowwise},"
f"rowwise_scale_inv={self._rowwise_scale_inv},"
f"amax_rowwise={self._amax_rowwise},"
f"amax_columnwise={self._amax_columnwise},"
")"
)
def update_usage(
self,
rowwise_usage: Optional[bool] = None,
columnwise_usage: Optional[bool] = None,
):
"""
For the NVFP4 format, columnwise scaled output is only produced by x2
scaling kernels, so this function only disables usages.
"""
# Default usage is based on available data
if rowwise_usage is None:
rowwise_usage = self._rowwise_data is not None
if columnwise_usage is None:
columnwise_usage = self._columnwise_data is not None
# Update row-scaled data
if rowwise_usage:
if self._rowwise_data is None:
raise RuntimeError(
"Requested row-wise usage, but NVFP4Tensor is missing row-scaled NVFP4 data"
)
if self._rowwise_scale_inv is None:
raise RuntimeError(
"Requested row-wise usage, but NVFP4Tensor is missing row-scaled scale-inverses"
)
if self._amax_rowwise is None:
raise RuntimeError(
"Requested row-wise usage, but NVFP4Tensor is missing per tensor"
" row-scaled scale-inverse"
)
else:
self._rowwise_data = None
self._rowwise_scale_inv = None
self._amax_rowwise = None
# Update column-scaled data
if columnwise_usage:
if self._columnwise_data is None:
raise RuntimeError(
"Requested column-wise usage, but NVFP4Tensor is missing column-scaled FP8 data"
)
if self._columnwise_scale_inv is None:
raise RuntimeError(
"Requested column-wise usage, "
"but NVFP4Tensor is missing column-scaled scale-inverses"
)
if self._amax_columnwise is None:
raise RuntimeError(
"Requested column-wise usage, "
"but NVFP4Tensor is missing per tensor column-scaled scale-inverse"
)
else:
self._columnwise_data = None
self._columnwise_scale_inv = None
self._amax_columnwise = None
......@@ -4,12 +4,14 @@
"""Helper functions for using fp8 tensors as weights"""
import os
from typing import Optional, Union
import torch
from torch.utils.cpp_extension import IS_HIP_EXTENSION
import transformer_engine_torch as tex
from transformer_engine_torch import multi_tensor_scale, multi_tensor_compute_scale_and_scale_inv
from .quantized_tensor import QuantizedTensor
from .quantized_tensor import QuantizedTensor, Quantizer, QuantizedTensorStorage
from .float8_tensor import Float8Tensor, Float8Quantizer, Float8CurrentScalingQuantizer
from .mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer
from .float8_blockwise_tensor import Float8BlockwiseQTensor, Float8BlockQuantizer
......@@ -455,3 +457,20 @@ def _cast_master_weights_to_fp8_blockwise_scaling(
tex.fp8_block_scaling_partial_cast(
master_weight, model_weight_fragment, scale, h, w, start_offset, block_len, fp8_dtype
)
def is_experimental(x: Optional[Union[Quantizer, QuantizedTensorStorage]] = None) -> bool:
"""Check if an environment or object is using experimental Kitchen middleware.
Returns False if x is a torch.Tensor.
"""
# Detect if the environment is experimental
if x is None:
return int(os.getenv("QAT_PARAMS", "0")) > 0
# Detect if the object is experimental
if isinstance(x, torch.Tensor):
return False
if not isinstance(x, (Quantizer, QuantizedTensorStorage)):
raise AssertionError("Object must be a Quantizer or QuantizedTensorStorage instance")
return hasattr(x, "experimental") and x.experimental
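# Usage sketch (assumed objects, shown only for illustration):
#
#   is_experimental()                 # True iff the QAT_PARAMS env var is set to > 0
#   is_experimental(torch.empty(1))   # False: a plain torch.Tensor is never experimental
#   is_experimental(some_quantizer)   # True iff the quantizer defines .experimental = True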
......@@ -191,6 +191,17 @@ class TransformerLayer(torch.nn.Module):
and `DotProductAttention` modules.
name: str, default = `None`
name of the module, currently used for debugging purposes.
softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla'
softmax type as described in this paper:
`Efficient Streaming Language Models with Attention Sinks
<https://arxiv.org/pdf/2309.17453v3>`_.
For a given attention score S = Q*K^T, of shape [b, h, s_q, s_kv],
'vanilla': S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1),
'off-by-one': S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and
'learnable': S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)),
where alpha is a learnable parameter in shape [h].
'off-by-one' and 'learnable' softmax types are also called sink attention
('zero sink' and 'learnable sink').
Parallelism parameters
----------------------
......@@ -306,6 +317,7 @@ class TransformerLayer(torch.nn.Module):
qk_norm_type: Optional[str] = None,
qk_norm_eps: float = 1e-6,
qk_norm_before_rope: bool = False,
softmax_type: str = "vanilla",
) -> None:
super().__init__()
......@@ -362,6 +374,7 @@ class TransformerLayer(torch.nn.Module):
self.get_rng_state_tracker = get_rng_state_tracker
self.attn_input_format = attn_input_format
self.softmax_type = softmax_type
self.name = name
......@@ -397,6 +410,7 @@ class TransformerLayer(torch.nn.Module):
"qkv_format": self.attn_input_format,
"seq_length": seq_length,
"micro_batch_size": micro_batch_size,
"softmax_type": self.softmax_type,
}
self.self_attention = MultiheadAttention(
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""NVFP4 padding kernels
TODO(ksivamani): Documentation
"""
import torch
import triton
import triton.language as tl
@triton.autotune(
configs=[
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=4, num_stages=2),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 256}, num_warps=4, num_stages=2),
triton.Config({"BLOCK_M": 256, "BLOCK_N": 128}, num_warps=8, num_stages=2),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 256}, num_warps=8, num_stages=1),
],
key=["out_dim0", "out_dim1"],
)
@triton.jit
def zero_pad_kernel(
inp_ptr,
out_ptr,
in_dim0: tl.constexpr,
in_dim1: tl.constexpr,
out_dim0: tl.constexpr,
out_dim1: tl.constexpr,
in_s0,
in_s1,
out_s0,
out_s1,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
):
"""Pads a tensor assuming it's a columnwise scaling inverse."""
# tile over OUTPUT coordinates
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) # output rows
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) # output cols
om = offs_m[:, None]
on = offs_n[None, :]
# edge masking for output
out_mask = (om < out_dim0) & (on < out_dim1)
# valid input region is simply top-left (no offsets)
in_mask = (om < in_dim0) & (on < in_dim1)
# load valid input, else zero (masked load touches memory only where True)
x = tl.load(inp_ptr + om * in_s0 + on * in_s1, mask=in_mask, other=0)
# store to output (only within bounds of the output tile)
tl.store(out_ptr + om * out_s0 + on * out_s1, x, mask=out_mask)
def pad_columnwise_scale_inv(inp: torch.Tensor) -> torch.Tensor:
"""Pads a tensor assuming it's a columnwise scaling inverse."""
assert inp.ndim == 2
dim0, dim1 = inp.shape
pad_x = (128 - dim0 % 128) % 128
pad_y = (4 - dim1 % 4) % 4
out_x = dim0 + pad_x
out_y = dim1 + pad_y
out = torch.empty((out_x, out_y), device=inp.device, dtype=inp.dtype)
in_s0, in_s1 = inp.stride()
out_s0, out_s1 = out.stride()
BLOCK_M, BLOCK_N = 128, 128
grid = (triton.cdiv(out_x, BLOCK_M), triton.cdiv(out_y, BLOCK_N))
zero_pad_kernel[grid](
inp,
out,
dim0,
dim1,
out_x,
out_y,
in_s0,
in_s1,
out_s0,
out_s1,
)
return out
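# Worked example (hypothetical sizes): a (100, 6) scale-inverse tensor is padded up
# to the next multiples of (128, 4), producing a (128, 8) output whose top-left
# 100x6 block holds the original values and whose remainder is zero.
#
#   scale_inv = torch.randint(0, 255, (100, 6), dtype=torch.uint8, device="cuda")
#   padded = pad_columnwise_scale_inv(scale_inv)
#   assert padded.shape == (128, 8)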
......@@ -324,7 +324,8 @@ def _permute_kernel(
pid_h = tl.program_id(1)
cur_off = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = cur_off < hidden_size
input_off = pid_t * stride_input_token + cur_off * stride_input_hidden
src_row = pid_t.to(tl.int64)
input_off = src_row * stride_input_token + cur_off * stride_input_hidden
inp = tl.load(input_ptr + input_off, mask=mask)
if PERMUTE_SCALE:
mask_scale = cur_off < scale_hidden_dim
......@@ -338,7 +339,7 @@ def _permute_kernel(
for idx in tl.range(n_routed):
dst_row = tl.load(
row_id_map_ptr + pid_t * stride_row_id_map_token + idx * stride_row_id_map_expert
)
).to(tl.int64)
output_off = dst_row * stride_output_token + cur_off * stride_output_hidden
if PERMUTE_SCALE:
permuted_scale_off = (
......@@ -519,7 +520,7 @@ def _unpermute_kernel(
for idx in tl.range(n_routed):
src_row = tl.load(
row_id_map_ptr + pid_t * stride_row_id_map_token + idx * stride_row_id_map_expert
)
).to(tl.int64)
input_off = src_row * stride_input_token + current_offset * stride_input_hidden
inp = tl.load(input_ptr + input_off, mask=mask)
inp = inp.to(compute_type)
......@@ -550,7 +551,8 @@ def _unpermute_kernel(
prob = tl.load(permuted_probs_ptr + permuted_prob_off)
tl.store(unpermuted_probs_ptr + unpermuted_prob_off, prob)
accumulator = accumulator.to(data_type)
output_off = pid_t * stride_output_token + current_offset * stride_output_hidden
dst_row = pid_t.to(tl.int64)
output_off = dst_row * stride_output_token + current_offset * stride_output_hidden
tl.store(output_ptr + output_off, accumulator, mask=mask)
......@@ -681,7 +683,7 @@ def _unpermute_bwd_with_merging_probs_kernel(
for idx in tl.range(n_routed):
dst_row = tl.load(
row_id_map_ptr + pid * stride_row_id_map_token + idx * stride_row_id_map_expert
)
).to(tl.int64)
expert_idx = tl.load(
row_id_map_ptr
+ pid * stride_row_id_map_token
......@@ -692,8 +694,10 @@ def _unpermute_bwd_with_merging_probs_kernel(
while current_start < hidden_size:
current_offset = current_start + tl.arange(0, BLOCK_SIZE)
mask = current_offset < hidden_size
src_row = pid.to(tl.int64)
input_off = (
pid * stride_fwd_output_grad_token + current_offset * stride_fwd_output_grad_hidden
src_row * stride_fwd_output_grad_token
+ current_offset * stride_fwd_output_grad_hidden
)
inp = tl.load(fwd_output_grad_ptr + input_off, mask=mask)
inp = inp.to(compute_type)
......@@ -902,11 +906,11 @@ def _sort_chunks_by_map_kernel(
pid_t = tl.program_id(0)
pid_h = tl.program_id(1)
if FORWARD:
src_row = pid_t
dst_row = tl.load(row_id_map_ptr + pid_t)
src_row = pid_t.to(tl.int64)
dst_row = tl.load(row_id_map_ptr + pid_t).to(tl.int64)
else:
src_row = tl.load(row_id_map_ptr + pid_t)
dst_row = pid_t
src_row = tl.load(row_id_map_ptr + pid_t).to(tl.int64)
dst_row = pid_t.to(tl.int64)
current_offset = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = current_offset < hidden_size
input_offsets = src_row * stride_input_token + current_offset * stride_input_hidden
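# Why the .to(tl.int64) casts in these kernels matter (illustrative arithmetic, not part of
# the diff): tl.program_id and values loaded from an int32 row-id map are 32-bit, so a
# row * stride product computed in int32 can wrap around once the flat element offset
# exceeds 2**31 - 1. Casting the row index to int64 first keeps the offset math in 64-bit.
num_permuted_rows, hidden_size = 600_000, 8_192  # hypothetical MoE sizes
flat_offset = num_permuted_rows * hidden_size  # ~4.9e9 elements
assert flat_offset > 2**31 - 1  # an int32 offset would overflow here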
......
......@@ -10,10 +10,14 @@ import os
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
import numpy as np
import torch
import transformer_engine.pytorch.cpp_extensions as ext
from . import torch_version
from .tensor.quantized_tensor import Quantizer
from torch.utils.cpp_extension import IS_HIP_EXTENSION
__all__ = ["get_device_compute_capability", "get_cudnn_version", "is_bf16_available"]
def requires_grad(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None:
    """Check if any of the given tensors require gradient."""
    for tensor in tensors:
......@@ -182,7 +186,7 @@ def combine_tensors(
num_tensors = len(tensors)
new_shape = list(tensors[0].shape)
new_shape.insert(dim, num_tensors)
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
if isinstance(tensors[0], Float8Tensor):
new_stride = list(tensors[0]._data.stride())
......@@ -222,14 +226,16 @@ class SplitAlongDim(torch.autograd.Function):
# pylint: disable=missing-function-docstring
ctx.split_dim = split_dim
ctx.split_size_or_sections = split_size_or_sections
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.tensor._internal.float8_tensor_base import Float8TensorBase
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
from transformer_engine.pytorch.tensor.storage.float8_tensor_storage import (
Float8TensorStorage,
)
if isinstance(mixed_x_layer, Float8TensorBase) and not isinstance(
if isinstance(mixed_x_layer, Float8TensorStorage) and not isinstance(
mixed_x_layer, Float8Tensor
):
return tuple(
Float8TensorBase(
Float8TensorStorage(
fp8_scale_inv=mixed_x_layer._scale_inv,
fp8_dtype=mixed_x_layer._fp8_dtype,
data=x.squeeze(split_dim) if squeeze else x,
......@@ -274,7 +280,7 @@ class SplitAlongDim(torch.autograd.Function):
split_sizes = [ctx.split_size_or_sections] * len(grad_outputs)
dims = len(grad_outputs[0].shape)
split_dim = (ctx.split_dim + dims) % dims
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
if isinstance(grad_outputs[0], Float8Tensor):
noop_ok = True
......@@ -454,14 +460,23 @@ if IS_HIP_EXTENSION:
import re
return (re.search('BW', torch.cuda.get_device_name(torch.cuda.current_device())) is not None)
def assert_dim_for_all_gather(
    tensor: torch.Tensor, with_all_gather: bool, quantizer: Quantizer
) -> None:
    """Assert that the tensor's dimensions are supported for all-gather."""
    if with_all_gather:
        assert quantizer.is_quantizable(tensor), (
            "All-gather requires a quantizable tensor for quantizer " + quantizer.__class__.__name__
        )
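# Hypothetical usage sketch for assert_dim_for_all_gather (the surrounding names are
# assumptions, not part of this commit): guard an input before a sequence-parallel
# all-gather so unsupported tensors fail fast; the helper is a no-op when
# with_all_gather is False.
def _maybe_check_gather_input(
    inp: torch.Tensor, quantizer: Quantizer, sequence_parallel: bool
) -> None:
    assert_dim_for_all_gather(inp, with_all_gather=sequence_parallel, quantizer=quantizer)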
def is_bf16_compatible() -> None:
def is_bf16_compatible() -> bool:
    """Replaces torch.cuda.is_bf16_supported() with an explicit
    check on device compute capability to enforce sm_80 or higher.
    """
    if IS_HIP_EXTENSION:
        # only MI200 and MI300 machines support bf16
        if get_device_compute_capability() == (9, 4) or is_mi200() or is_K100_AI() or is_BW():
        if get_device_compute_capability() >= (9, 4) or is_mi200() or is_K100_AI() or is_BW():
            return True
        else:
            return False
......@@ -469,6 +484,29 @@ def is_bf16_compatible() -> None:
    return torch.cuda.get_device_capability()[0] >= 8
def is_bf16_available(return_reason: bool = False) -> Union[bool, Tuple[bool, str]]:
    """
    Determine whether bfloat16 (BF16) computation is supported on the current device.

    Parameters
    ----------
    return_reason : bool, optional
        If ``False`` (default), return only a boolean indicating BF16 availability.
        If ``True``, return a tuple ``(is_available, reason)`` where ``reason`` provides
        a human-readable explanation when BF16 is not available. When BF16 is available,
        the reason will be an empty string.
    """
    available = is_bf16_compatible()
    if not return_reason:
        return available
    reason = (
        "" if available else "BF16 support requires a GPU with compute capability 8.0 or higher."
    )
    return available, reason
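# Minimal usage sketch for is_bf16_available (illustrative, not from the commit): callers
# can branch on the boolean or surface the reason when BF16 is unavailable.
bf16_ok, why = is_bf16_available(return_reason=True)
compute_dtype = torch.bfloat16 if bf16_ok else torch.float32
if not bf16_ok:
    print(f"BF16 unavailable, falling back to FP32: {why}")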
@functools.lru_cache(maxsize=None)
def is_non_tn_fp8_gemm_supported(is_blockwise: Optional[bool] = False) -> bool:
"""Checks whether the device supports
......@@ -486,6 +524,8 @@ def is_non_tn_fp8_gemm_supported(is_blockwise: Optional[bool] = False) -> bool:
@functools.lru_cache(maxsize=None)
def get_cudnn_version() -> Tuple[int, int, int]:
    """Runtime cuDNN version (major, minor, patch)"""
    import transformer_engine.pytorch.cpp_extensions as ext

    # ROCm fused attention does not use cuDNN; return a high version so tests are not filtered out
    if IS_HIP_EXTENSION:
        return (99, 0, 0)
......
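# Minimal usage sketch (illustrative, not from the commit): get_cudnn_version returns a
# plain tuple, so version gates compare lexicographically; on ROCm the (99, 0, 0) stub
# above makes such gates always pass.
major, minor, patch = get_cudnn_version()
if (major, minor, patch) >= (9, 0, 0):
    pass  # e.g. enable a cuDNN-9-only fused-attention path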