Commit 0d874a4e authored by wenjh

Merge branch 'nv_main' of v2.12

parents a68e5f87 dfdd3820
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""MoE Permutaion API"""
"""MoE Permutation API"""
import warnings
from typing import Optional, Tuple
import torch
......@@ -191,6 +191,7 @@ class _moe_permute_mask_map(torch.autograd.Function):
routing_map: torch.Tensor,
num_out_tokens: int,
probs: torch.Tensor,
pad_offsets: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# pylint: disable=missing-function-docstring
if not inp.numel():
......@@ -201,6 +202,8 @@ class _moe_permute_mask_map(torch.autograd.Function):
assert routing_map.is_cuda, "TransformerEngine needs CUDA."
if probs is not None:
assert probs.is_cuda, "TransformerEngine needs CUDA."
if pad_offsets is not None:
assert pad_offsets.is_cuda, "TransformerEngine needs CUDA."
assert inp.size(0) == routing_map.size(0), "Permute not possible"
num_tokens, hidden_size = inp.size()
......@@ -250,6 +253,7 @@ class _moe_permute_mask_map(torch.autograd.Function):
row_id_map,
probs,
fp8_scale,
pad_offsets,
num_tokens,
num_experts,
num_out_tokens,
......@@ -290,9 +294,10 @@ class _moe_permute_mask_map(torch.autograd.Function):
columnwise_scale_inv=None,
quantizer=None,
requires_grad=output.requires_grad,
with_gemm_swizzled_scales=False,
)
ctx.save_for_backward(row_id_map)
ctx.save_for_backward(row_id_map, pad_offsets)
ctx.num_experts = num_experts
ctx.num_tokens = num_tokens
ctx.hidden_size = hidden_size
......@@ -307,12 +312,12 @@ class _moe_permute_mask_map(torch.autograd.Function):
) -> Tuple[torch.Tensor, ...]:
# pylint: disable=missing-function-docstring
if not permuted_act_grad.numel():
return permuted_act_grad, None, None, ctx.probs
return permuted_act_grad, None, None, ctx.probs, None
act_grad = None
probs_grad = None
if ctx.needs_input_grad[0]:
(row_id_map,) = ctx.saved_tensors
row_id_map, pad_offsets = ctx.saved_tensors
assert not isinstance(
permuted_act_grad, QuantizedTensor
), "The backward of moe_permute does not support FP8."
......@@ -321,13 +326,14 @@ class _moe_permute_mask_map(torch.autograd.Function):
row_id_map,
None,
permuted_probs_grad,
pad_offsets,
ctx.num_tokens,
ctx.num_experts,
ctx.hidden_size,
)
if not ctx.needs_input_grad[3]:
probs_grad = None
return act_grad, None, None, probs_grad
return act_grad, None, None, probs_grad, None
class _moe_unpermute_mask_map(torch.autograd.Function):
......@@ -340,6 +346,7 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
row_id_map: torch.Tensor,
merging_probs: Optional[torch.Tensor],
restore_shape: Optional[torch.Size],
pad_offsets: Optional[torch.Tensor],
) -> torch.Tensor:
# pylint: disable=missing-function-docstring
if not inp.numel():
......@@ -358,6 +365,8 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
# Device check
assert inp.is_cuda, "TransformerEngine needs CUDA."
assert row_id_map.is_cuda, "TransformerEngine needs CUDA."
if pad_offsets is not None:
assert pad_offsets.is_cuda, "TransformerEngine needs CUDA."
assert not isinstance(
inp, QuantizedTensor
......@@ -367,15 +376,16 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
row_id_map,
merging_probs,
None,
pad_offsets,
num_tokens,
num_experts,
hidden_size,
)
if with_probs:
ctx.save_for_backward(inp, row_id_map, merging_probs)
ctx.save_for_backward(inp, row_id_map, merging_probs, pad_offsets)
else:
ctx.save_for_backward(row_id_map)
ctx.save_for_backward(row_id_map, pad_offsets)
ctx.num_experts = num_experts
ctx.num_tokens = num_tokens
ctx.num_permuted_tokens = inp.size(0)
......@@ -387,15 +397,15 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
def backward(ctx, unpermuted_act_grad):
# pylint: disable=missing-function-docstring
if not unpermuted_act_grad.numel():
return unpermuted_act_grad, None, ctx.merging_probs, None
return unpermuted_act_grad, None, ctx.merging_probs, None, None
act_grad = None
probs_grad = None
if ctx.needs_input_grad[0]:
if ctx.with_probs:
fwd_input, row_id_map, merging_probs = ctx.saved_tensors
fwd_input, row_id_map, merging_probs, pad_offsets = ctx.saved_tensors
else:
(row_id_map,) = ctx.saved_tensors
row_id_map, pad_offsets = ctx.saved_tensors
fp8 = isinstance(unpermuted_act_grad, QuantizedTensor)
per_tensor_recipe = isinstance(unpermuted_act_grad, Float8Tensor)
......@@ -441,6 +451,7 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
row_id_map,
fwd_input,
merging_probs,
pad_offsets,
ctx.num_tokens,
ctx.num_experts,
ctx.num_permuted_tokens,
......@@ -453,6 +464,7 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
row_id_map,
None,
fp8_scale,
pad_offsets,
ctx.num_tokens,
ctx.num_experts,
ctx.num_permuted_tokens,
......@@ -493,11 +505,12 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
columnwise_scale_inv=None,
quantizer=None,
requires_grad=act_grad.requires_grad,
with_gemm_swizzled_scales=False,
)
if not ctx.needs_input_grad[2]:
probs_grad = None
return act_grad, None, probs_grad, None
return act_grad, None, probs_grad, None, None
def moe_permute(
......@@ -514,22 +527,22 @@ def moe_permute(
Parameters
----------
inp: torch.Tensor
inp : torch.Tensor
Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
routing_map: torch.Tensor
routing_map : torch.Tensor
The token to expert mapping tensor.
If map_type is 'mask', routing_map is of shape [num_tokens, num_experts] and dtype 'int32'.
The values in it: 1 means the token is routed to this expert and 0 means not.
If map_type is 'index', routing_map is of shape [num_tokens, topK] and dtype 'int32'.
The values in it are the routed expert indices.
num_out_tokens: int, default = -1
num_out_tokens : int, default = -1
The effective output token count, representing the number of tokens not dropped.
By default, set to '-1', meaning no tokens are dropped.
max_token_num: int, default = -1
max_token_num : int, default = -1
The maximum number of tokens, used for workspace allocation.
By default, set to '-1', meaning the calculation of the size of workspace is
automatically taken over by the operator.
map_type: str, default = 'mask'
map_type : str, default = 'mask'
Type of the routing map tensor.
Options are: 'mask', 'index'.
Refer to `routing_map` for more details.
......@@ -537,7 +550,9 @@ def moe_permute(
if map_type == "index":
return _moe_permute_index_map.apply(inp, routing_map, num_out_tokens, max_token_num)
if map_type == "mask":
output, row_id_map, _ = _moe_permute_mask_map.apply(inp, routing_map, num_out_tokens, None)
output, row_id_map, _ = _moe_permute_mask_map.apply(
inp, routing_map, num_out_tokens, None, None
)
return output, row_id_map
raise ValueError("map_type should be one of 'mask' or 'index'")
......@@ -556,25 +571,81 @@ def moe_permute_with_probs(
Parameters
----------
inp: torch.Tensor
inp : torch.Tensor
Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
probs: torch.Tensor
probs : torch.Tensor
The tensor of probabilities corresponding to the permuted tokens and is
of shape [num_tokens, num_experts]. It will be permuted with the tokens
according to the routing_map.
routing_map: torch.Tensor
routing_map : torch.Tensor
The token to expert mapping tensor of shape [num_tokens, num_experts] and dtype 'int32'.
The values in it: 1 means the token is routed to this expert and 0 means not.
num_out_tokens: int, default = -1
num_out_tokens : int, default = -1
The effective output token count, representing the number of tokens not dropped.
By default, set to '-1', meaning no tokens are dropped.
"""
output, row_id_map, permuted_probs = _moe_permute_mask_map.apply(
inp, routing_map, num_out_tokens, probs
inp, routing_map, num_out_tokens, probs, None
)
return output, permuted_probs, row_id_map
def moe_permute_and_pad_with_probs(
inp: torch.Tensor,
probs: torch.Tensor,
routing_map: torch.Tensor,
tokens_per_expert: torch.Tensor,
align_size: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
"""
Permute the tokens and probs based on the routing_map.
Tokens routed to the same expert will be grouped together.
The routing_map indicates which experts were selected by each token.
Parameters
----------
inp : torch.Tensor
Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
probs : torch.Tensor
The tensor of probabilities of shape [num_tokens, num_experts]. It will be
permuted together with the tokens according to the routing_map.
routing_map : torch.Tensor
The token to expert mapping tensor of shape [num_tokens, num_experts] and dtype 'int32'.
The values in it: 1 means the token is routed to this expert and 0 means not.
tokens_per_expert : torch.Tensor
Tensor of shape `[num_experts]` containing actual token counts per expert.
align_size : int
The alignment size; each expert's token count is padded up to a multiple of this value.
"""
assert (
tokens_per_expert is not None
), "tokens_per_expert must be provided to the fused permute padding function."
assert align_size > 0, f"align_size must be positive, got {align_size}"
# Ensure tokens_per_expert is on the same device as input to avoid device transfers
if tokens_per_expert.device != inp.device:
tokens_per_expert = tokens_per_expert.to(inp.device)
# Calculate aligned token counts per expert
target_tokens_per_expert = (torch.ceil(tokens_per_expert / align_size) * align_size).long()
if torch.equal(tokens_per_expert, target_tokens_per_expert):
pad_offsets = None
else:
pad_lengths = target_tokens_per_expert - tokens_per_expert
cum_pad = torch.cumsum(pad_lengths, dim=0)
pad_offsets = torch.cat(
[torch.zeros(1, dtype=cum_pad.dtype, device=inp.device), cum_pad[:-1]]
)
output, row_id_map, permuted_probs = _moe_permute_mask_map.apply(
inp, routing_map, target_tokens_per_expert.sum().item(), probs, pad_offsets
)
return output, permuted_probs, row_id_map, pad_offsets, target_tokens_per_expert
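The pad_offsets arithmetic has a simple reading: with align_size = 4 and tokens_per_expert = [5, 3, 8], the counts are rounded up to [8, 4, 8], the per-expert padding is [3, 1, 0], and pad_offsets = [0, 3, 4], i.e. the cumulative padding inserted before each expert's chunk. A minimal usage sketch follows; the import path, tensor shapes, and the top-2 routing construction are illustrative assumptions, not part of this change.

import torch
# assumed import path; adjust to wherever these helpers are exposed in your build
from transformer_engine.pytorch.permutation import moe_permute_and_pad_with_probs

num_tokens, hidden_size, num_experts, align = 16, 128, 4, 8
inp = torch.randn(num_tokens, hidden_size, device="cuda", dtype=torch.bfloat16)
probs = torch.rand(num_tokens, num_experts, device="cuda")
# route every token to its top-2 experts (mask-type routing map, dtype int32)
topk_idx = probs.topk(2, dim=-1).indices
routing_map = torch.zeros_like(probs, dtype=torch.int32).scatter_(1, topk_idx, 1)
tokens_per_expert = routing_map.sum(dim=0)

out, permuted_probs, row_id_map, pad_offsets, padded_counts = moe_permute_and_pad_with_probs(
    inp, probs, routing_map, tokens_per_expert, align
)
# out has padded_counts.sum() rows; pad_offsets is None when every count is already aligned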
def moe_unpermute(
inp: torch.Tensor,
row_id_map: torch.Tensor,
......@@ -582,6 +653,7 @@ def moe_unpermute(
restore_shape: Optional[torch.Size] = None,
map_type: str = "mask",
probs: Optional[torch.Tensor] = None,
pad_offsets: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Unpermute a tensor with permuted tokens, and optionally merge the tokens with their
......@@ -589,22 +661,26 @@ def moe_unpermute(
Parameters
----------
inp: torch.Tensor
inp : torch.Tensor
Input tensor with permuted tokens of shape `[num_tokens, hidden_size]` to be unpermuted.
row_id_map: torch.Tensor
row_id_map : torch.Tensor
The tensor of a mapping table for sorted indices used to unpermute the tokens,
which is the second output tensor of `Permute`.
merging_probs: torch.Tensor, default = None
merging_probs : torch.Tensor, default = None
The tensor of probabilities corresponding to the permuted tokens. If provided,
the unpermuted tokens will be merged with their respective probabilities.
By default, set to an empty tensor, which means that the tokens are directly merged by accumulation.
restore_shape: torch.Size, default = None
restore_shape : torch.Size, default = None
The output shape after the unpermute operation.
map_type: str, default = 'mask'
map_type : str, default = 'mask'
Type of the routing map tensor. Should be the same as the value passed to moe_permute.
Options are: 'mask', 'index'.
probs: torch.Tensor, default = None
probs : torch.Tensor, default = None
Renamed to merging_probs. Keep for backward compatibility.
pad_offsets : torch.Tensor, default = None
Tensor of per-expert cumulative padding offsets used to remove padding added
during permutation. This is the fourth output of `moe_permute_and_pad_with_probs`
and is required when unpermuting padded outputs.
"""
if probs is not None:
if merging_probs is not None:
......@@ -616,7 +692,9 @@ def moe_unpermute(
if map_type == "index":
return _moe_unpermute_index_map.apply(inp, row_id_map, merging_probs)
if map_type == "mask":
return _moe_unpermute_mask_map.apply(inp, row_id_map, merging_probs, restore_shape)
return _moe_unpermute_mask_map.apply(
inp, row_id_map, merging_probs, restore_shape, pad_offsets
)
raise ValueError("map_type should be one of 'mask' or 'index'")
......@@ -733,11 +811,11 @@ def moe_sort_chunks_by_index(
Parameters
----------
inp: torch.Tensor
inp : torch.Tensor
Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
split_sizes: torch.Tensor
split_sizes : torch.Tensor
Chunk sizes of the inp tensor along the 0-th dimension.
sorted_indices: torch.Tensor
sorted_indices : torch.Tensor
Chunk indices used to permute the chunks.
"""
output, _ = _moe_chunk_sort.apply(inp, split_sizes, sorted_index, None)
......@@ -757,15 +835,15 @@ def moe_sort_chunks_by_index_with_probs(
Parameters
----------
inp: torch.Tensor
inp : torch.Tensor
Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
probs: torch.Tensor
probs : torch.Tensor
The tensor of probabilities corresponding to the permuted tokens and is
of shape [num_tokens]. It will be permuted with the tokens according to
the split_sizes and sorted_indices.
split_sizes: torch.Tensor
split_sizes : torch.Tensor
Chunk sizes of the inp tensor along the 0-th dimension.
sorted_indices: torch.Tensor
sorted_indices : torch.Tensor
Chunk indices used to permute the chunks.
"""
output, permuted_probs = _moe_chunk_sort.apply(inp, split_sizes, sorted_index, probs)
......
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
[build-system]
requires = ["setuptools>=61.0", "pip", "torch>=2.1"]
# Use legacy backend to import local packages in setup.py
build-backend = "setuptools.build_meta:__legacy__"
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......@@ -26,7 +26,6 @@ from transformer_engine.common.recipe import (
NVFP4BlockScaling,
CustomRecipe,
)
from .constants import dist_group_type
from .utils import (get_device_compute_capability, is_gfx928, is_gfx936, is_gfx938)
from .jit import jit_fuser
......@@ -43,6 +42,7 @@ __all__ = [
"is_fp8_block_scaling_available",
"is_nvfp4_available",
"get_default_recipe",
"get_align_size_for_quantization",
]
@functools.lru_cache(maxsize=None)
......@@ -131,6 +131,15 @@ def get_default_recipe() -> Recipe:
return get_default_fp8_recipe()
def get_align_size_for_quantization(recipe: Recipe) -> int:
"""Get the alignment size for quantization."""
if recipe.mxfp8():
return 32
if recipe.nvfp4():
return 128
return 16
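A quick sketch of the mapping, assuming the recipe classes imported above construct with their defaults; the returned value is what feeds the align_size argument of moe_permute_and_pad_with_probs so that each expert's token count is padded to a quantization-friendly multiple.

from transformer_engine.common.recipe import (
    MXFP8BlockScaling,
    NVFP4BlockScaling,
    Float8BlockScaling,
)

assert get_align_size_for_quantization(MXFP8BlockScaling()) == 32
assert get_align_size_for_quantization(NVFP4BlockScaling()) == 128
assert get_align_size_for_quantization(Float8BlockScaling()) == 16  # fallback for other recipes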
def get_fp8_torch_dtype(fp8_recipe: Recipe, fprop_tensor: bool = True) -> torch.dtype:
"""Get fp8 data type according to recipe and tensor"""
if fp8_recipe.fp8_format == Format.E4M3 or (
......@@ -685,7 +694,7 @@ def fp8_model_init(
.. warning::
fp8_model_init is deprecated and will be removed in a future release. Use
quantized_model_init(enabled=..., recipe=..., preserve_high_precision_init_val=...) instead.
``quantized_model_init(enabled=..., recipe=..., preserve_high_precision_init_val=...)`` instead.
"""
......@@ -730,7 +739,7 @@ def quantized_model_init(
Parameters
----------
enabled: bool, default = `True`
enabled : bool, default = True
when enabled, Transformer Engine modules created inside this `quantized_model_init`
region will hold only quantized copies of their parameters, as opposed to the default
behavior where both higher precision and quantized copies are present. Setting this
......@@ -741,9 +750,9 @@ def quantized_model_init(
precision copies of weights are already present in the optimizer.
* inference, where only the quantized copies of the parameters are used.
* LoRA-like fine-tuning, where the main parameters of the model do not change.
recipe: transformer_engine.common.recipe.Recipe, default = `None`
recipe : transformer_engine.common.recipe.Recipe, default = None
Recipe used to create the parameters. If left to None, it uses the default recipe.
preserve_high_precision_init_val: bool, default = `False`
preserve_high_precision_init_val : bool, default = False
when enabled, store the high precision tensor used to initialize quantized parameters
in CPU memory, and add two function attributes named `get_high_precision_init_val()`
and `clear_high_precision_init_val()` to quantized parameters to get/clear this high
......@@ -780,8 +789,8 @@ def fp8_autocast(
"""
.. warning::
fp8_autocast is deprecated and will be removed in a future release.
Use autocast(enabled=..., calibrating=..., recipe=..., group=..., _graph=...) instead.
``fp8_autocast`` is deprecated and will be removed in a future release.
Use ``autocast(enabled=..., calibrating=..., recipe=..., group=..., _graph=...)`` instead.
"""
......@@ -835,16 +844,16 @@ def autocast(
Parameters
----------
enabled: bool, default = `True`
enabled : bool, default = True
whether or not to enable low precision quantization (FP8/FP4).
calibrating: bool, default = `False`
calibrating : bool, default = False
calibration mode allows collecting statistics such as amax and scale
data of quantized tensors even when executing without quantization enabled.
This is useful for saving an inference ready checkpoint while training
using a higher precision.
recipe: recipe.Recipe, default = `None`
recipe : recipe.Recipe, default = None
recipe used for low precision quantization.
amax_reduction_group: torch._C._distributed_c10d.ProcessGroup, default = `None`
amax_reduction_group : torch._C._distributed_c10d.ProcessGroup, default = None
distributed group over which amaxes for the quantized tensors
are reduced at the end of each training step.
"""
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......@@ -7,7 +7,6 @@
from __future__ import annotations
from typing import Optional, Tuple, Iterable, Any, Dict, Union
import abc
import copy
import warnings
import math
......@@ -21,14 +20,9 @@ from transformer_engine.pytorch.tensor._quantization_helpers import (
_stride_from_shape,
)
_quantized_tensor_cpu_supported_ops = (
torch.ops.aten.empty_like.default,
torch.ops.aten.copy_.default,
)
class QuantizedTensorStorage:
r"""Base class for all *TensorStorage classes.
r"""Base class for all TensorStorage classes.
This class (and its subclasses) are optimization for when
the full QuantizedTensor is not needed (when it is fully
......@@ -55,11 +49,11 @@ class QuantizedTensorStorage:
Parameters
----------
rowwise_usage : Optional[bool[, default = `None`
rowwise_usage : Optional[bool], default = None
Whether to create or keep the data needed for using the tensor
in rowwise fashion (e.g. as B argument in TN GEMM). Leaving it as `None`
preserves the original value in the tensor.
columnwise_usage : Optional[bool], default = `None`
columnwise_usage : Optional[bool], default = None
Whether to create or keep the data needed for using the tensor
in columnwise fashion (e.g. as A argument in TN GEMM). Leaving it as
`None` preserves the original value in the tensor.
......@@ -129,7 +123,7 @@ def prepare_for_saving(
]:
"""Prepare tensors for saving. Needed because save_for_backward accepts only
torch.Tensor/torch.nn.Parameter types, while we want to be able to save
the internal *TensorStorage types too."""
the internal TensorStorage types too."""
tensor_list, tensor_objects_list = [], []
for tensor in tensors:
......@@ -205,10 +199,21 @@ class Quantizer(abc.ABC):
"""
internal: bool
"""Whether to solely optimize for matrix multiplication
The resulting quantized tensors are not guaranteed to support any
operation other than matrix multiplication. Use with care since
this is likely to break communication, checkpointing, and many
other features.
"""
optimize_for_gemm: bool
def __init__(self, *, rowwise: bool, columnwise: bool) -> None:
self.rowwise_usage = rowwise
self.columnwise_usage = columnwise
self.internal = False
self.optimize_for_gemm = False
def __repr__(self):
return (
......@@ -297,10 +302,6 @@ class Quantizer(abc.ABC):
if columnwise is not None:
self.columnwise_usage = columnwise
def copy(self) -> Quantizer:
"""Create shallow copy"""
return copy.copy(self)
def onnx_quantize(self, tensor: torch.Tensor) -> QuantizedTensor:
"""Symbolic function for ONNX export"""
raise NotImplementedError(
......@@ -324,7 +325,11 @@ class Quantizer(abc.ABC):
return False
def is_quantizable(self, inp: torch.Tensor) -> bool: # pylint: disable=unused-argument
"""Returns whether or not given tensor can be quantized"""
"""Whether tensor supports quantized all-gather
Consider a less misleading function name.
"""
return True
def get_usages(self) -> Dict[str, bool]:
......@@ -544,15 +549,6 @@ class QuantizedTensor(torch.Tensor):
if kwargs is None:
kwargs = {}
def check_if_cpu(arg):
if isinstance(cls, QuantizedTensor) and arg.device.type == "cpu":
assert (
func in _quantized_tensor_cpu_supported_ops
), f"QuantizedTensor on CPU does not support this operation: {func}"
return arg
args = tree_map(check_if_cpu, args)
# Do not force the QuantizedTensor type on the returned tensor
return torch._C._disabled_torch_function_impl(func, types, args, kwargs)
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
......@@ -92,24 +92,24 @@ def fused_topk_with_score_function(
Fused topk with score function router.
Parameters
----------
logits: torch.Tensor
topk: int
use_pre_softmax: bool
logits : torch.Tensor
topk : int
use_pre_softmax : bool
if enabled, the computation order: softmax -> topk
num_groups: int
num_groups : int
used in the group topk
group_topk: int
group_topk : int
used in the group topk
scaling_factor: float
score_function: str
scaling_factor : float
score_function : str
currently only support softmax and sigmoid
expert_bias: torch.Tensor
expert_bias : torch.Tensor
could be used in the sigmoid
Returns
-------
probs: torch.Tensor
routing_map: torch.Tensor
probs : torch.Tensor
routing_map : torch.Tensor
"""
if logits.dtype == torch.float64:
raise ValueError("Current TE does not support float64 router type")
......@@ -186,15 +186,15 @@ def fused_compute_score_for_moe_aux_loss(
Fused compute scores for MoE aux loss, subset of the fused_topk_with_score_function.
Parameters
----------
logits: torch.Tensor
topk: int
score_function: str
logits : torch.Tensor
topk : int
score_function : str
currently only support softmax and sigmoid
Returns
-------
routing_map: torch.Tensor
scores: torch.Tensor
routing_map : torch.Tensor
scores : torch.Tensor
"""
return FusedComputeScoresForMoEAuxLoss.apply(logits, topk, score_function)
......@@ -258,18 +258,18 @@ def fused_moe_aux_loss(
Fused MoE aux loss.
Parameters
----------
probs: torch.Tensor
tokens_per_expert: torch.Tensor
probs : torch.Tensor
tokens_per_expert : torch.Tensor
the number of tokens per expert
total_num_tokens: int
total_num_tokens : int
the total number of tokens involved in the aux loss calculation
num_experts: int
topk: int
coeff: float
num_experts : int
topk : int
coeff : float
the coefficient of the aux loss
Returns
-------
aux_loss: torch.scalar
aux_loss : torch.Tensor (scalar)
"""
return FusedAuxLoss.apply(probs, tokens_per_expert, total_num_tokens, num_experts, topk, coeff)
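A hedged usage sketch of the fused aux loss defined above (shapes and values are illustrative; arguments follow the parameter list in the docstring):

import torch

num_tokens, num_experts, topk = 1024, 8, 2
probs = torch.rand(num_tokens, num_experts, device="cuda")
tokens_per_expert = torch.randint(1, num_tokens, (num_experts,), device="cuda", dtype=torch.int32)
aux_loss = fused_moe_aux_loss(
    probs, tokens_per_expert, num_tokens, num_experts, topk, coeff=1e-2
)
# aux_loss is a scalar tensor, already scaled by coeff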
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......@@ -75,21 +75,29 @@ def get_platform():
def get_wheel_url():
"""Construct the wheel URL for the current platform."""
torch_version_raw = parse(torch.__version__)
python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
platform_name = get_platform()
nvte_version = te_version()
torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}"
cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()
# Determine the version numbers that will be used to determine the correct wheel
# We're using the CUDA version used to build torch, not the one currently installed
# _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
torch_cuda_version = parse(torch.version.cuda)
# For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.3
# For CUDA 12 we only compile for CUDA 12.3
# to save CI time. Minor versions should be compatible.
torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.3")
# cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}"
if torch_cuda_version.major == 12:
torch_cuda_version = parse("12.3")
elif torch_cuda_version.major == 13:
torch_cuda_version = parse("13.0")
else:
raise ValueError(f"CUDA version {torch_cuda_version} not supported")
if os.environ.get("NVIDIA_PRODUCT_NAME", "") == "PyTorch":
torch_version = str(os.environ.get("NVIDIA_PYTORCH_VERSION"))
else:
torch_version = f"{torch.__version__}"
cuda_version = f"{torch_cuda_version.major}"
# Determine wheel URL based on CUDA version, torch version, python version and OS
......@@ -109,8 +117,10 @@ class CachedWheelsCommand(_bdist_wheel):
"""
def run(self):
"""Acts a proxy before _bdist_wheel.run() and downloads a prebuilt wheel if available."""
if FORCE_BUILD:
super().run()
return
wheel_url, wheel_filename = get_wheel_url()
print("Guessing wheel URL: ", wheel_url)
......@@ -129,10 +139,12 @@ class CachedWheelsCommand(_bdist_wheel):
wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
print("Raw wheel path", wheel_path)
os.rename(wheel_filename, wheel_path)
return
except (urllib.error.HTTPError, urllib.error.URLError):
print("Precompiled wheel not found. Building from source...")
# If the wheel could not be downloaded, build from source
super().run()
return
if __name__ == "__main__":
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Tensor class with FP8 data quantized with NxN tiles"""
from __future__ import annotations
from typing import Optional, Tuple, Iterable, Union
from collections.abc import Iterable
import math
from typing import Any, Optional, Tuple, Union
import torch
import transformer_engine_torch as tex
import os
from transformer_engine_torch import DType as TE_DType
from transformer_engine_torch import Float8BlockScaleTensorFormat
from transformer_engine.common.recipe import Float8BlockScaling, Recipe
from .storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage
from ..quantized_tensor import QuantizedTensor, Quantizer
......@@ -38,8 +38,6 @@ class Float8BlockQuantizer(Quantizer):
amax_epsilon: float
force_pow_2_scales: bool
block_scaling_dim: int
# Whether to produce tensors that will be used in all-gather
all_gather_usage: bool
def __init__(
self,
......@@ -50,7 +48,6 @@ class Float8BlockQuantizer(Quantizer):
amax_epsilon: float = 0.0,
force_pow_2_scales: bool = True,
block_scaling_dim: int = 2,
all_gather_usage: bool = False,
) -> None:
super().__init__(rowwise=rowwise, columnwise=columnwise)
self.dtype = tex.DType.kInt8 if int8_simulation_fp8 else fp8_dtype
......@@ -58,7 +55,22 @@ class Float8BlockQuantizer(Quantizer):
self.force_pow_2_scales = force_pow_2_scales
self.amax_epsilon = amax_epsilon
self.block_scaling_dim = block_scaling_dim
self.all_gather_usage = all_gather_usage
def copy(self) -> Float8BlockQuantizer:
"""Create shallow copy"""
quantizer = Float8BlockQuantizer(
fp8_dtype=self.dtype,
rowwise=self.rowwise_usage,
columnwise=self.columnwise_usage,
block_scaling_dim=self.block_scaling_dim,
amax_epsilon=self.amax_epsilon,
force_pow_2_scales=self.force_pow_2_scales,
)
quantizer.internal = self.internal
quantizer.optimize_for_gemm = self.optimize_for_gemm
return quantizer
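With the base-class copy() (which relied on copy.copy) removed, each quantizer now provides an explicit copy that also carries over the internal and optimize_for_gemm flags. A small sketch of the intended use, deriving a second quantizer with different usage flags without touching the original (constructor arguments mirror the ones used in copy() above):

import transformer_engine_torch as tex

base = Float8BlockQuantizer(
    fp8_dtype=tex.DType.kFloat8E4M3, rowwise=True, columnwise=True, block_scaling_dim=1
)
derived = base.copy()
derived.set_usage(rowwise=True, columnwise=False)
assert base.columnwise_usage and not derived.columnwise_usage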
def update_quantized(
self,
......@@ -110,103 +122,86 @@ class Float8BlockQuantizer(Quantizer):
return tex.quantize(tensor, self)
def get_scale_shape(self, shape: Iterable[int], columnwise: bool) -> Tuple[int, int]:
"""Calculate the shape of the scaling tensor for blockwise quantization.
"""Scaling tensor shape.
This method determines the shape of the scaling tensor needed for blockwise quantization,
taking into account the input tensor shape and whether columnwise scaling is used.
The scales are padded to multiples of 4 on the inner dimension for compatibility with GEMM.
This method determines the shape of the scaling tensor based
on the quantizer configuration. The scales are padded to
multiples of 4 for compatibility with GEMM.
Parameters
----------
shape : Iterable[int]
Shape of the input tensor to be quantized
Logical tensor shape.
columnwise : bool
Whether to use columnwise scaling (True) or rowwise scaling (False)
Whether the data is scaled column-wise (True) or row-wise (False).
Returns
-------
Tuple[int, int]
Shape of the scaling tensor as (outer_dim, inner_dim)
For 2D tensors:
- If columnwise: (roundup(K/blocksize), round_to_multiple(roundup(M/blocksize), 4))
- If rowwise: (roundup(M/blocksize), round_to_multiple(roundup(K/blocksize), 4))
For 1D tensors:
- If columnwise: (roundup(M/blocksize), round_to_multiple(K, 4))
- If rowwise: (roundup(K/blocksize), round_to_multiple(M, 4))
Scaling tensor shape.
"""
M, K = 1, 1
for i in range(len(shape) - 1):
M *= shape[i]
if len(shape) > 0:
K = shape[-1]
# 2D 128x128 quantization block scaling
# CuBLAS requries 128x128 scaling factor to be padded
# currently rowwise and columnwise format option doesn't apply to 2D scaling
# Flatten tensor to 2D
dim0 = math.prod(shape[:-1])
dim1 = shape[-1] if shape else 1
# Check block dims
if self.block_scaling_dim not in (1, 2):
raise RuntimeError(
"Only 1D or 2D blocks are supported, "
f"but got block_scaling_dim={self.block_scaling_dim}"
)
# 128x128 block scaling
if self.block_scaling_dim == 2:
scale_dim0 = (dim0 + self.block_len - 1) // self.block_len
scale_dim1 = (dim1 + self.block_len - 1) // self.block_len
if columnwise:
outer = math.ceil(K / self.block_len)
inner = round_up_to_nearest_multiple(math.ceil(M / self.block_len), 4)
return (outer, inner)
# rowwise
outer = math.ceil(M / self.block_len)
inner = round_up_to_nearest_multiple(math.ceil(K / self.block_len), 4)
return (outer, inner)
# 1D 1x128 quantization block scaling
# CuBLAS requries 1x128 scaling factor to be padded and transposed
assert self.block_scaling_dim == 1, "Only 1D or 2D blocks supported"
return (scale_dim1, round_up_to_nearest_multiple(scale_dim0, 4))
return (scale_dim0, round_up_to_nearest_multiple(scale_dim1, 4))
# 1x128 block scaling
if columnwise:
columnwise_compact = self.all_gather_usage
outer = math.ceil(M / self.block_len)
inner = round_up_to_nearest_multiple(K, 4) if not columnwise_compact else K
# GEMM READY case: scaling factor is [outer, inner], already transposed here for CuBLAS
# for COMPACT case, since we apply 1x128 scaling here without transposing columnwise data, scaling factor is also [outer, inner]
# so no need to swap inner outer here
return (outer, inner)
# rowwise
rowwise_compact = self.all_gather_usage
outer = math.ceil(K / self.block_len)
inner = round_up_to_nearest_multiple(M, 4) if not rowwise_compact else M
# GEMM READY case: scaling factor is [outer, inner], already transposed here for CuBLAS need
# for COMPACT case, since we apply 128x1 scaling, scaling block applies to inner dim, so we need to swap outer and inner here
return (outer, inner) if not rowwise_compact else (inner, outer)
return (
(dim0 + self.block_len - 1) // self.block_len,
round_up_to_nearest_multiple(dim1, 4),
)
return (
(dim1 + self.block_len - 1) // self.block_len,
round_up_to_nearest_multiple(dim0, 4),
)
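A worked example of the scale shapes produced above, assuming the default 128-element block length and a logical (256, 512) tensor; the quantizer construction mirrors the copy() arguments shown earlier.

import transformer_engine_torch as tex

q1d = Float8BlockQuantizer(
    fp8_dtype=tex.DType.kFloat8E4M3, rowwise=True, columnwise=True, block_scaling_dim=1
)
q2d = Float8BlockQuantizer(
    fp8_dtype=tex.DType.kFloat8E4M3, rowwise=True, columnwise=True, block_scaling_dim=2
)
assert q1d.get_scale_shape((256, 512), columnwise=False) == (4, 256)  # ceil(512/128), round_up(256, 4)
assert q1d.get_scale_shape((256, 512), columnwise=True) == (2, 512)   # ceil(256/128), round_up(512, 4)
assert q2d.get_scale_shape((256, 512), columnwise=False) == (2, 4)    # ceil(256/128), round_up(ceil(512/128), 4)
assert q2d.get_scale_shape((256, 512), columnwise=True) == (4, 4)     # ceil(512/128), round_up(ceil(256/128), 4)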
def get_columnwise_shape(self, shape: Iterable[int]) -> Tuple[int, ...]:
"""Calculate the shape of a tensor after columnwise permutation.
"""Column-wise data shape
This method rearranges the dimensions of a tensor to be columnwise,
moving the last dimension to the front and keeping the order of other dimensions.
GEMMs expect that the column-wise data is transposed relative
to the logical tensor shape.
Parameters
----------
shape : Iterable[int]
Original shape of the tensor
Logical tensor shape.
Returns
-------
Tuple[int, ...]
New shape with dimensions rearranged for columnwise layout.
For a shape (d1, d2, ..., dn), returns (dn, d1, d2, ..., dn-1).
Returns empty tuple for empty input shape.
Column-wise data shape.
"""
if len(shape) == 0:
return tuple()
# currently columnwise format option only applies to 1D quantizer
# for 2D scaling, columnwise format should always be GEMM_READY_DATA_AND_SCALES
# since currently 2D scaling only applies to module weights
if self.block_scaling_dim == 1 and self.all_gather_usage:
return shape
colwise_shape = [shape[-1]]
for i in range(len(shape) - 1):
colwise_shape.append(shape[i])
colwise_shape = []
if shape:
colwise_shape.append(shape[-1])
colwise_shape.extend(shape[:-1])
return tuple(colwise_shape)
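The column-wise layout simply rotates the last dimension to the front; continuing with the q1d quantizer from the sketch above:

assert q1d.get_columnwise_shape((256, 512)) == (512, 256)
assert q1d.get_columnwise_shape((4, 8, 16)) == (16, 4, 8)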
def is_quantizable(self, inp: torch.Tensor) -> bool:
"""Returns whether or not given inp can be quantized"""
if inp.ndim < 2:
shape = inp.size()
if len(shape) < 2:
return False
if inp.shape[-1] % self.block_len != 0:
if shape[-1] % self.block_len != 0:
return False
if math.prod(inp.shape[:-1]) % self.block_len != 0:
if math.prod(shape[:-1]) % self.block_len != 0:
return False
return True
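is_quantizable only checks that the flattened dimensions divide evenly by the block length (again assuming a 128-element block); a quick illustration, still using q1d from above:

import torch

assert q1d.is_quantizable(torch.randn(256, 512))      # both dims divisible by 128
assert not q1d.is_quantizable(torch.randn(100, 512))  # 100 % 128 != 0
assert not q1d.is_quantizable(torch.randn(512))       # fewer than 2 dims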
......@@ -220,44 +215,36 @@ class Float8BlockQuantizer(Quantizer):
pin_memory: bool = False,
) -> Float8BlockwiseQTensor:
"""Construct quantized tensor with uninitialized data"""
if device is None:
device = torch.device("cuda")
data_format = (
tex.Float8BlockScaleTensorFormat.COMPACT
if self.all_gather_usage
else tex.Float8BlockScaleTensorFormat.GEMM_READY
)
tensor_kwargs = {
"device": torch.device("cuda") if device is None else device,
"pin_memory": pin_memory,
}
# Allocate FP8 data
data = None
scale_inv = None
# Allocate buffers for row-scaled data
rowwise_data = None
rowwise_scale_inv = None
if self.rowwise_usage:
data = torch.empty(shape, dtype=torch.uint8, device=device, pin_memory=pin_memory)
scale_shape = self.get_scale_shape(shape, columnwise=False)
scale_inv = torch.empty(
scale_shape,
rowwise_data = torch.empty(shape, dtype=torch.uint8, **tensor_kwargs)
rowwise_scale_inv = torch.empty(
self.get_scale_shape(shape, columnwise=False),
dtype=torch.float32,
device=device,
pin_memory=pin_memory,
**tensor_kwargs,
)
# Allocate FP8 data transpose if needed
# Allocate buffers for column-scaled data
columnwise_data = None
columnwise_scale_inv = None
if self.columnwise_usage:
columnwise_data = torch.empty(
self.get_columnwise_shape(shape),
dtype=torch.uint8,
device=device,
pin_memory=pin_memory,
**tensor_kwargs,
)
columnwise_scale_shape = self.get_scale_shape(shape, columnwise=True)
columnwise_scale_inv = torch.empty(
columnwise_scale_shape,
self.get_scale_shape(shape, columnwise=True),
dtype=torch.float32,
device=device,
pin_memory=pin_memory,
**tensor_kwargs,
)
# Construct FP8 tensor
......@@ -265,13 +252,12 @@ class Float8BlockQuantizer(Quantizer):
shape=shape,
dtype=dtype,
fp8_dtype=self.dtype,
rowwise_data=data,
rowwise_scale_inv=scale_inv,
rowwise_data=rowwise_data,
rowwise_scale_inv=rowwise_scale_inv,
columnwise_data=columnwise_data,
columnwise_scale_inv=columnwise_scale_inv,
quantizer=self,
is_2D_scaled=self.block_scaling_dim == 2,
data_format=data_format,
requires_grad=requires_grad,
)
......@@ -294,18 +280,18 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorStorage, QuantizedTensor):
Parameters
----------
rowwise_data: torch.Tensor
rowwise_data : torch.Tensor
FP8 data in a uint8 tensor matching shape of dequantized tensor.
rowwise_scale_inv: torch.Tensor
rowwise_scale_inv : torch.Tensor
FP32 dequantization scales in GEMM format for dequantizing rowwise_data.
columnwise_data: Optional[torch.Tensor]
columnwise_data : Optional[torch.Tensor]
FP8 data in a uint8 tensor matching shape of dequantized tensor transpose.
columnwise_scale_inv: Optional[torch.Tensor]
columnwise_scale_inv : Optional[torch.Tensor]
FP32 dequantization scales in GEMM format for dequantizing columnwise_data.
fp8_dtype: transformer_engine_torch.DType, default = kFloat8E4M3
fp8_dtype : transformer_engine_torch.DType, default = kFloat8E4M3
FP8 format.
quantizer: Quantizer - the Float8BlockQuantizer that quantized this tensor and
quantizer : Quantizer - the Float8BlockQuantizer that quantized this tensor and
holds configuration about quantization and dequantization modes.
"""
......@@ -321,7 +307,6 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorStorage, QuantizedTensor):
fp8_dtype: TE_DType,
quantizer: Quantizer,
is_2D_scaled: bool,
data_format: tex.Float8BlockScaleTensorFormat = Float8BlockScaleTensorFormat.GEMM_READY,
**kwargs,
):
instance = super().__new__(
......@@ -333,7 +318,6 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorStorage, QuantizedTensor):
fp8_dtype,
quantizer,
is_2D_scaled,
data_format,
*args,
**kwargs,
)
......@@ -344,8 +328,7 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorStorage, QuantizedTensor):
return (
f"Float8BlockwiseQTensor(fp8_dtype={self._fp8_dtype},"
f" is_2D_scaled={self._is_2D_scaled},"
f" data={self.dequantize(dtype=self.dtype)}),"
f" data_format={self._data_format}"
f" data={self.dequantize(dtype=self.dtype)})"
)
def quantize_(
......@@ -496,7 +479,7 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorStorage, QuantizedTensor):
dtype: torch.dtype,
quantizer: Quantizer,
is_2D_scaled: bool,
data_format: tex.Float8BlockScaleTensorFormat,
data_format: Any = None, # pylint: disable=unused-argument
) -> Float8BlockwiseQTensor:
"""Build Float8BlockwiseQTensor, for use in __reduce__
......@@ -514,7 +497,6 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorStorage, QuantizedTensor):
dtype=dtype,
quantizer=quantizer,
is_2D_scaled=is_2D_scaled,
data_format=data_format,
)
def __reduce_ex__(self, protocol: int) -> tuple:
......@@ -531,7 +513,7 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorStorage, QuantizedTensor):
self.dtype,
self._quantizer,
self._is_2D_scaled,
self._data_format,
None, # data_format
),
)
......@@ -557,7 +539,6 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorStorage, QuantizedTensor):
dst._fp8_dtype = src._fp8_dtype
dst._rowwise_scale_inv = src._rowwise_scale_inv
dst._columnwise_scale_inv = src._columnwise_scale_inv
dst._data_format = src._data_format
# Check that tensor dimensions match
if (
......@@ -605,13 +586,6 @@ class _ViewFunc(torch.autograd.Function):
) -> Float8BlockwiseQTensor:
# pylint: disable=missing-function-docstring
# Check for invalid configurations
if not tensor._is_gemm_ready_format():
raise NotImplementedError(
"View is only supported with GEMM_READY data format, "
f"but found data_format={tensor._data_format}"
)
# Return input tensor if shape is not provided
ctx.shape = tensor.shape
if shape is None:
......@@ -680,14 +654,6 @@ class _ViewFunc(torch.autograd.Function):
# pylint: disable=missing-function-docstring
if isinstance(grad, Float8BlockwiseQTensor):
# Check for invalid configurations
if not grad._is_gemm_ready_format():
raise NotImplementedError(
"View is only supported with GEMM_READY data format, "
f"but found data_format={grad._data_format}"
)
new_data = (
grad._rowwise_data.view(*ctx.shape) if grad._rowwise_data is not None else None
)
......@@ -727,13 +693,6 @@ class _ReshapeFunc(torch.autograd.Function):
) -> Float8BlockwiseQTensor:
# pylint: disable=missing-function-docstring
# Check for invalid configurations
if not tensor._is_gemm_ready_format():
raise NotImplementedError(
"Reshape is only supported with GEMM_READY data format, "
f"but found data_format={tensor._data_format}"
)
# Return input tensor if shape is not provided
ctx.shape = tensor.shape
if shape is None:
......@@ -801,14 +760,6 @@ class _ReshapeFunc(torch.autograd.Function):
# pylint: disable=missing-function-docstring
if isinstance(grad, Float8BlockwiseQTensor):
# Check for invalid configurations
if not grad._is_gemm_ready_format():
raise NotImplementedError(
"Reshape is only supported with GEMM_READY data format, "
f"but found data_format={grad._data_format}"
)
new_rowwise_data = None
new_columnwise_data = None
if grad._rowwise_data is not None:
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......@@ -67,6 +67,20 @@ class Float8Quantizer(Quantizer):
self.amax = amax
self.dtype = fp8_dtype
def copy(self) -> Float8Quantizer:
"""Create shallow copy"""
quantizer = Float8Quantizer(
scale=self.scale,
amax=self.amax,
fp8_dtype=self.dtype,
rowwise=self.rowwise_usage,
columnwise=self.columnwise_usage,
)
quantizer.internal = self.internal
return quantizer
def update_quantized(
self,
src: torch.Tensor,
......@@ -246,10 +260,16 @@ class Float8CurrentScalingQuantizer(Quantizer):
amax_reduction_group: Optional[dist_group_type] = None,
force_pow_2_scales: bool = False,
amax_epsilon: float = 0.0,
scale: Optional[torch.Tensor] = None,
amax: Optional[torch.Tensor] = None,
) -> None:
super().__init__(rowwise=rowwise, columnwise=columnwise)
self.scale = torch.empty(1, dtype=torch.float32, device=device)
self.amax = torch.empty(1, dtype=torch.float32, device=device)
if scale is None:
scale = torch.empty(1, dtype=torch.float32, device=device)
if amax is None:
amax = torch.empty(1, dtype=torch.float32, device=device)
self.scale = scale
self.amax = amax
self.dtype = tex.DType.kInt8 if int8_simulation_fp8_tensorwise else fp8_dtype
self.use_existing_amax = use_existing_amax
self.with_amax_reduction = with_amax_reduction
......@@ -257,6 +277,27 @@ class Float8CurrentScalingQuantizer(Quantizer):
self.force_pow_2_scales = force_pow_2_scales
self.amax_epsilon = amax_epsilon
def copy(self) -> Float8CurrentScalingQuantizer:
"""Create shallow copy"""
quantizer = Float8CurrentScalingQuantizer(
fp8_dtype=self.dtype,
device=0,
rowwise=self.rowwise_usage,
columnwise=self.columnwise_usage,
with_amax_reduction=self.with_amax_reduction,
amax_reduction_group=self.amax_reduction_group,
use_existing_amax=self.use_existing_amax,
force_pow_2_scales=self.force_pow_2_scales,
amax_epsilon=self.amax_epsilon,
scale=self.scale,
amax=self.amax,
)
quantizer.internal = self.internal
quantizer.optimize_for_gemm = self.optimize_for_gemm
return quantizer
def update_quantized(
self,
src: torch.Tensor,
......@@ -414,23 +455,23 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor):
Parameters
----------
shape: int or iterable of int
shape : int or iterable of int
Tensor dimensions.
dtype: torch.dtype
dtype : torch.dtype
Nominal tensor datatype.
requires_grad: bool, optional = False
requires_grad : bool, optional = False
Whether to compute gradients for this tensor.
data: torch.Tensor
data : torch.Tensor
Raw FP8 data in a uint8 tensor
fp8_scale_inv: torch.Tensor
fp8_scale_inv : torch.Tensor
Reciprocal of the scaling factor applied when casting to FP8,
i.e. the scaling factor that must be applied when casting from
FP8 to higher precision.
fp8_dtype: transformer_engine_torch.DType
fp8_dtype : transformer_engine_torch.DType
FP8 format.
data_transpose: torch.Tensor, optional
data_transpose : torch.Tensor, optional
FP8 transpose data in a uint8 tensor
quantizer: Float8Quantizer, Float8CurrentScalingQuantizer, optional
quantizer : Float8Quantizer, Float8CurrentScalingQuantizer, optional
Builder class for FP8 tensors
"""
......@@ -454,10 +495,10 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor):
# Convert PyTorch dtype to TE dtype
if dtype is None:
dtype = self.dtype
tensor = self.contiguous()
if torch.is_grad_enabled():
return _FromFloat8Func.apply(self, dtype)
return _FromFloat8Func.forward(None, self, dtype)
return _FromFloat8Func.apply(tensor, dtype)
return _FromFloat8Func.forward(None, tensor, dtype)
def quantize_(
self,
......@@ -512,18 +553,31 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor):
) -> Float8Tensor:
"""Returns tensor with data in provided memory format
Returns `self` if data is already in correct memory format.
Returns ``self`` if data is already in correct memory format.
"""
if self._data is not None and self._data.is_contiguous(memory_format=memory_format):
return self
if self._transpose is not None and self._transpose.is_contiguous(
# Check if tensor already has correct memory format
if self._data is not None and not self._data.is_contiguous(memory_format=memory_format):
pass
elif self._transpose is not None and not self._transpose.is_contiguous(
memory_format=memory_format
):
pass
else:
# Tensor has correct memory format, so return immediately
return self
return Float8Tensor.make_like(tensor=self, data=self._data.contiguous())
# raise ValueError("Float8Tensor does not support different memory formats!")
# Construct tensor with correct data format
data, data_transpose = None, None
if self._data is not None:
data = self._data.contiguous(memory_format=memory_format)
if self._transpose is not None and not self._transpose_invalid:
data_transpose = self._transpose.contiguous(memory_format=memory_format)
return _IdentityFunc.apply(
self,
{"data": data, "data_transpose": data_transpose},
)
def _reset_caches(self) -> None:
"""
......@@ -674,9 +728,8 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor):
[transpose, t_shape] + list(args[2:]),
kwargs,
)
# deep copy the scale inverse tensor and quantizer as well.
scale_inv = tensor._scale_inv.detach().clone()
quantizer = tensor._quantizer.copy()
quantizer = tensor._quantizer # Deep-copied in constructor
out_tensor = Float8Tensor(
data=func_out,
shape=func_out.shape,
......@@ -781,7 +834,7 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor):
# sure that updated Quantized weight tensor have same scale inverse across all shards.
self._quantizer.amax_reduction_group = mesh.get_group()
self._quantizer.with_amax_reduction = True
quantizer = self._quantizer.copy() # quantizer to be used for allgathered weights
fsdp_state = _get_module_fsdp_state(module)
reshard_after_forward = fsdp_state._fsdp_param_group._reshard_after_forward
# If weights are resharded after forward pass, then its enough to set the quantizer usages
......@@ -794,9 +847,13 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor):
is_backward_pass = training_state == TrainingState.PRE_BACKWARD
# In case of hopper/L40, only one of data/transpose is needed
# based on forward or backward pass. So setting the quantizer usages appropriately.
quantizer.set_usage(rowwise=not is_backward_pass, columnwise=is_backward_pass)
rowwise_usage = not is_backward_pass
columnwise_usage = is_backward_pass
else:
rowwise_usage = True
columnwise_usage = self._quantizer.columnwise_usage
sharded_tensors = (self._data,)
metadata = (self._scale_inv, self._fp8_dtype, quantizer)
metadata = (self._scale_inv, rowwise_usage, columnwise_usage, self._fp8_dtype)
return sharded_tensors, metadata
def fsdp_post_all_gather(
......@@ -822,7 +879,7 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor):
"""
(data,) = all_gather_outputs
(fp8_scale_inv, fp8_dtype, quantizer) = metadata
(fp8_scale_inv, rowwise_usage, columnwise_usage, fp8_dtype) = metadata
orig_shape = data.size()
# Quantizer has only columnwise usage set for backward pass
# In Blackwell+ architectures, transpose is not needed at all,
......@@ -831,20 +888,27 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor):
if out is not None:
out._data = data
else:
# We ll be here when post all gather is called the first time.
# Float8Tensor constructor makes a copy of the quantizer to
# save as its own quantizer. For the consequent iterations,
# the same quantizer is used. Copy is needed in the first iteration,
# since we need different quantizers for sharded and allgathered tensors.
# and self._quantizer belongs to the sharded parameter.
fp8_args = {
"shape": orig_shape,
"dtype": param_dtype,
"fp8_scale_inv": fp8_scale_inv,
"fp8_dtype": fp8_dtype,
"quantizer": quantizer,
"quantizer": self._quantizer,
"requires_grad": False,
"data": data,
}
out = Float8Tensor(**fp8_args)
out._quantizer.set_usage(rowwise=rowwise_usage, columnwise=columnwise_usage)
out.update_usage(
rowwise_usage=quantizer.rowwise_usage,
columnwise_usage=quantizer.columnwise_usage,
rowwise_usage=rowwise_usage,
columnwise_usage=columnwise_usage,
)
return out, all_gather_outputs
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......@@ -45,6 +45,19 @@ class MXFP8Quantizer(Quantizer):
super().__init__(rowwise=rowwise, columnwise=columnwise)
self.dtype = fp8_dtype
def copy(self) -> MXFP8Quantizer:
"""Create shallow copy"""
quantizer = MXFP8Quantizer(
fp8_dtype=self.dtype,
rowwise=self.rowwise_usage,
columnwise=self.columnwise_usage,
)
quantizer.internal = self.internal
quantizer.optimize_for_gemm = self.optimize_for_gemm
return quantizer
def update_quantized(
self,
src: torch.Tensor,
......@@ -122,7 +135,9 @@ class MXFP8Quantizer(Quantizer):
columnwise_data = None
columnwise_scale_inv = None
if self.columnwise_usage:
columnwise_data = torch.empty_like(data, pin_memory=pin_memory)
columnwise_data = torch.empty(
shape, dtype=torch.uint8, device=device, pin_memory=pin_memory
)
columnwise_scale_inv = torch.empty(
round_up_to_nearest_multiple(math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE, 4),
round_up_to_nearest_multiple(shape[-1], 128),
......@@ -142,6 +157,7 @@ class MXFP8Quantizer(Quantizer):
columnwise_scale_inv=columnwise_scale_inv,
quantizer=self,
requires_grad=requires_grad,
with_gemm_swizzled_scales=self.optimize_for_gemm,
)
def calibrate(self, tensor: torch.Tensor) -> None:
......@@ -165,6 +181,7 @@ class MXFP8Quantizer(Quantizer):
columnwise_scale_inv=None,
fp8_dtype=fp8_dtype,
quantizer=self,
with_gemm_swizzled_scales=False,
)
def onnx_quantize(self, tensor: torch.Tensor) -> QuantizedTensor:
......@@ -174,6 +191,10 @@ class MXFP8Quantizer(Quantizer):
return self.create_tensor_from_data(data, scale_inv, fake_dtype=torch.float32)
def onnx_dequantize(self, tensor: Union[MXFP8TensorStorage, MXFP8Tensor]) -> torch.Tensor:
if tensor._with_gemm_swizzled_scales:
raise NotImplementedError(
"ONNX MXFP8 dequantization is only supported with scales in compact format."
)
return torch.ops.tex.mxfp8_dequantize(tensor._rowwise_data, tensor._rowwise_scale_inv)
def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
......@@ -190,16 +211,16 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
Parameters
----------
data: torch.Tensor
data : torch.Tensor
Raw FP8 data in a uint8 tensor
fp8_dtype: transformer_engine_torch.DType, default = kFloat8E4M3
fp8_dtype : transformer_engine_torch.DType, default = kFloat8E4M3
FP8 format.
fp8_scale_inv: torch.Tensor
fp8_scale_inv : torch.Tensor
Reciprocal of the scaling factor applied when
casting to FP8, i.e. the scaling factor that must
be applied when casting from FP8 to higher
precision.
dtype: torch.dtype, default = torch.float32
dtype : torch.dtype, default = torch.float32
Nominal tensor datatype.
"""
......@@ -215,9 +236,10 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
columnwise_scale_inv: Optional[torch.Tensor],
fp8_dtype: TE_DType,
quantizer: Optional[Quantizer],
with_gemm_swizzled_scales: bool,
**kwargs,
):
instance = super().__new__(
return super().__new__(
cls,
rowwise_data,
rowwise_scale_inv,
......@@ -225,10 +247,10 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
columnwise_scale_inv,
fp8_dtype,
quantizer,
with_gemm_swizzled_scales,
*args,
**kwargs,
)
return instance
def __repr__(self, *, tensor_contents=None):
return f"MXFP8Tensor(fp8_dtype={self._fp8_dtype}, data={self.dequantize(dtype=self.dtype)})"
......@@ -320,39 +342,44 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None):
# View op
if func == aten.view.default:
tensor = args[0]
data = tensor._rowwise_data
out_data = data.__torch_dispatch__(
func,
types,
[data] + list(args[1:]),
kwargs,
)
out_shape = out_data.size()
shape = args[1]
if len(shape) < 2 or shape[-1] != tensor.size(-1):
raise ValueError(
f"Attempted to make view with size={tuple(shape)} "
f"from MXFP8 tensor with shape={tuple(tensor.size())}."
)
rowwise_data_view = None
columnwise_data_view = None
if tensor._rowwise_data is not None:
rowwise_data_view = tensor._rowwise_data.view(shape)
if tensor._columnwise_data is not None:
columnwise_data_view = tensor._columnwise_data.view(shape)
return MXFP8Tensor(
shape=out_shape,
shape=shape,
dtype=tensor.dtype,
rowwise_data=out_data,
rowwise_data=rowwise_data_view,
rowwise_scale_inv=tensor._rowwise_scale_inv,
columnwise_data=tensor._columnwise_data,
columnwise_data=columnwise_data_view,
columnwise_scale_inv=tensor._columnwise_scale_inv,
quantizer=tensor._quantizer,
requires_grad=False,
fp8_dtype=tensor._fp8_dtype,
with_gemm_swizzled_scales=tensor._with_gemm_swizzled_scales,
)
if func == torch.ops.aten.copy_.default:
dst, src = args[0], args[1]
if isinstance(src, MXFP8Tensor) and isinstance(dst, MXFP8Tensor):
# Booleans to check if src has all the usages that dst needs to respect dst quantizer usages.
# If not, default to base class behavior.
rowwise_matches = src._rowwise_data is not None or dst._rowwise_data is None
columnwise_matches = (
src._columnwise_data is not None or dst._columnwise_data is None
)
if rowwise_matches and columnwise_matches:
if src._rowwise_data is None and dst._rowwise_data is not None:
pass
elif src._columnwise_data is None and dst._columnwise_data is not None:
pass
elif src._with_gemm_swizzled_scales != dst._with_gemm_swizzled_scales:
pass
else:
# src and dst match, so we can directly copy data
if dst._rowwise_data is not None:
dst._rowwise_data.copy_(src._rowwise_data.detach(), *args[2:], **kwargs)
dst._rowwise_scale_inv.copy_(
......@@ -367,26 +394,25 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
)
return dst
# FSDP2 related functions.
if func == aten.split.Tensor:
# This is called if entire model is initialized on CUDA device and
# then splitted. Finally the shard needed by the process is used
# and other splitted shards are discarded.
# With FSDP2, this is called if the entire model is
# initialized on a CUDA device and then split. The shard
# needed by this process is kept and the other shards are
# discarded.
tensor = args[0]
split_size = args[1]
if "dim" in kwargs:
dim_to_split = kwargs["dim"]
else:
dim_to_split = args[2] if len(args) > 2 else 0
tensor = args[0]
split_size = args[1]
dim0_size = tensor.size(0)
dimlast_size = math.prod(tensor.shape[1:])
# Fall back to high-precision if split is non-trivial
if (
dim0_size % split_size != 0
or dim_to_split != 0
dim_to_split != 0
or tensor.size(0) % split_size != 0
or split_size % MXFP8_BLOCK_SCALING_SIZE != 0
or dimlast_size % MXFP8_BLOCK_SCALING_SIZE != 0
or tensor._with_gemm_swizzled_scales
):
# Handle splitting by dequantizing and splitting the hp tensor
return super().__torch_dispatch__(func, types, args, kwargs)
out_data = []
......@@ -420,13 +446,16 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
if scale_inv is not None
else None
)
scale_inv_out = list(scale_inv_out) if scale_inv_out is not None else None
# Pad scale_inv_out to be a multiple of pad_multiple
if scale_inv_out is not None:
current_shape = scale_inv_out.shape
pad_dim0 = (pad_multiple - current_shape[0] % pad_multiple) % pad_multiple
if pad_dim0 > 0:
scale_inv_out = torch.nn.functional.pad(scale_inv_out, (0, 0, 0, pad_dim0))
for idx, split_scale_inv_out in enumerate(scale_inv_out):
current_shape = split_scale_inv_out.shape
pad_dim0 = (pad_multiple - current_shape[0] % pad_multiple) % pad_multiple
if pad_dim0 > 0:
scale_inv_out[idx] = torch.nn.functional.pad(
split_scale_inv_out, (0, 0, 0, pad_dim0)
)
out_data.append(scale_inv_out)
return [
MXFP8Tensor(
......@@ -443,28 +472,26 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
quantizer=tensor._quantizer,
requires_grad=False,
fp8_dtype=tensor._fp8_dtype,
with_gemm_swizzled_scales=False,
)
for splitted_tensor_data in zip(*out_data)
]
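The scale-inv padding above uses the standard round-up-to-a-multiple idiom; a tiny self-contained illustration (values are hypothetical):

pad_multiple = 4
for n in (3, 4, 5, 8):
    pad = (pad_multiple - n % pad_multiple) % pad_multiple
    assert (n + pad) % pad_multiple == 0  # padded length is always a multiple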
if func == torch.ops.aten.as_strided.default:
# Applied on the unsharded param in FSDP2; in our case this should be a no-op.
# It is needed for the case where some MXFP8 shards need padding, i.e., dimension 0
# of the unsharded param is not a multiple of the world size. If that is the case,
# we go down the dequantization route and weights are allgathered in high precision.
# If the weight doesn't need padding, this is just a no-op.
tensor = args[0]
shape = args[1]
strides = args[2]
tensor = args[0]
if (
len(shape) != 2
or len(strides) != 2
or strides[1] != 1
or shape[0] != tensor.shape[0]
or shape[1] != tensor.shape[1]
len(shape) == len(strides) == 2
and tuple(strides) == (shape[-1], 1)
and tuple(shape) == tuple(tensor.size())
):
return super().__torch_dispatch__(func, types, args, kwargs)
return MXFP8Tensor.make_like(tensor)
return MXFP8Tensor.make_like(tensor)
if func == aten.slice.Tensor:
# Function needed by FSDP2.
......@@ -472,19 +499,12 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
# of the unsharded param is not a multiple of the world size. If that is the case,
# we go down the dequantization route and weights are allgathered in high precision instead.
# If the sharded weight doesn't have padding, this is just a no-op.
tensor = args[0]
dim = args[1]
start = args[2]
length = args[3]
tensor = args[0]
if (
dim != 0
or length != tensor.shape[0]
or start != 0
or length % MXFP8_BLOCK_SCALING_SIZE != 0
or start % MXFP8_BLOCK_SCALING_SIZE != 0
):
return super().__torch_dispatch__(func, types, args, kwargs)
return MXFP8Tensor.make_like(tensor)
if start == 0 and length == tensor.size(dim):
return MXFP8Tensor.make_like(tensor)
if func == aten.new_zeros.default:
rowwise_data = None
......@@ -538,10 +558,12 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
rowwise_scale_inv=rowwise_scale_inv,
columnwise_data=columnwise_data,
columnwise_scale_inv=columnwise_scale_inv,
quantizer=tensor._quantizer.copy(),
quantizer=tensor._quantizer,
requires_grad=False,
fp8_dtype=tensor._fp8_dtype,
with_gemm_swizzled_scales=tensor._with_gemm_swizzled_scales,
)
# Default case
return super().__torch_dispatch__(func, types, args, kwargs)
......@@ -567,29 +589,32 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
# pylint: disable=unused-argument
from transformer_engine.pytorch.distributed import _get_module_fsdp_state
# Get FSDP state
fsdp_state = _get_module_fsdp_state(module)
reshard_after_forward = fsdp_state._fsdp_param_group._reshard_after_forward
quantizer = self._quantizer.copy()
# Remove padding from scale inverses before allgather
# Rowwise scale_inv should be divisible by [128,4], columnwise by [4, 128]
rowwise_scale_inv = self._rowwise_scale_inv
columnwise_scale_inv = self._columnwise_scale_inv
shape = self.shape
if self._with_gemm_swizzled_scales:
raise NotImplementedError(
"FSDP2 is only supported for MXFP8Tensors with compact scales"
)
if rowwise_scale_inv is not None:
# Remove padding from rowwise scale_inv
flattened_in_shape0 = math.prod(shape[:-1])
if rowwise_scale_inv.size(0) != flattened_in_shape0:
rowwise_scale_inv = rowwise_scale_inv[:flattened_in_shape0]
if columnwise_scale_inv is not None:
# Remove padding from columnwise scale_inv
flattened_in_shape0 = math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE
if columnwise_scale_inv.size(0) != flattened_in_shape0:
columnwise_scale_inv = columnwise_scale_inv[:flattened_in_shape0]
sharded_tensors = (self._rowwise_data, rowwise_scale_inv)
# If weights are resharded after the forward pass, then it's enough to set the quantizer usages
# based on whether it's the forward or backward pass for the allgathered weights.
# If weights are resharded after the forward pass, then it's enough to send one row/col
# usage flag based on whether it's the forward or backward pass for the allgathered weights.
# If not resharded after the forward pass, the same weights allgathered in forward
# are used again in backward. Hence, if we need the columnwise data/scale_inv,
# we need to send them as well for the allgather in the forward pass itself.
......@@ -597,18 +622,24 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
training_state = fsdp_state._fsdp_param_group._training_state
is_backward_pass = training_state == TrainingState.PRE_BACKWARD
# Allgather only the necessary tensors based on forward/backward pass
quantizer.set_usage(rowwise=not is_backward_pass, columnwise=is_backward_pass)
rowwise_usage = not is_backward_pass
columnwise_usage = is_backward_pass
sharded_tensors = (
(self._columnwise_data, columnwise_scale_inv)
if is_backward_pass
else sharded_tensors
else (self._rowwise_data, rowwise_scale_inv)
)
else:
if quantizer.columnwise_usage:
# rowwise usage is always needed for forward pass.
rowwise_usage = True
sharded_tensors = (self._rowwise_data, rowwise_scale_inv)
columnwise_usage = self._quantizer.columnwise_usage
if columnwise_usage:
# If weights are not resharded after forward, then both
# rowwise and columnwise data/scale_inv need to be allgathered.
sharded_tensors += (self._columnwise_data, columnwise_scale_inv)
metadata = (self._fp8_dtype, quantizer)
metadata = (self._fp8_dtype, rowwise_usage, columnwise_usage)
return sharded_tensors, metadata
def fsdp_post_all_gather(
......@@ -631,12 +662,10 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
Tuple[MXFP8Tensor, Tuple[torch.Tensor, ...]]: Allgathered MXFP8Tensor and the tuple of internal tensors
used by that MXFP8Tensor, computed after the allgather.
"""
fp8_dtype, quantizer = metadata
rowwise_data, rowwise_scale_inv = (
all_gather_outputs[:2] if quantizer.rowwise_usage else (None, None)
)
fp8_dtype, rowwise_usage, columnwise_usage = metadata
rowwise_data, rowwise_scale_inv = all_gather_outputs[:2] if rowwise_usage else (None, None)
columnwise_data, columnwise_scale_inv = (
all_gather_outputs[-2:] if quantizer.columnwise_usage else (None, None)
all_gather_outputs[-2:] if columnwise_usage else (None, None)
)
# Add padding to scale_inv tensors to be multiples of [128, 4] for rowwise and [4, 128] for columnwise
......@@ -661,8 +690,13 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
out._rowwise_scale_inv = rowwise_scale_inv
out._columnwise_data = columnwise_data
out._columnwise_scale_inv = columnwise_scale_inv
out._quantizer = quantizer
else:
# We'll be here when post-all-gather is called for the first time.
# The MXFP8Tensor constructor makes a copy of the quantizer to
# save as its own quantizer. For subsequent iterations,
# the same quantizer is used. The copy is needed in the first iteration,
# since we need different quantizers for the sharded and allgathered tensors,
# and self._quantizer belongs to the sharded parameter.
out = MXFP8Tensor(
rowwise_data=rowwise_data,
rowwise_scale_inv=rowwise_scale_inv,
......@@ -671,9 +705,10 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
fp8_dtype=fp8_dtype,
dtype=param_dtype,
shape=rowwise_data.shape if rowwise_data is not None else columnwise_data.shape,
quantizer=quantizer,
quantizer=self._quantizer,
with_gemm_swizzled_scales=False,
)
out._quantizer.set_usage(rowwise=rowwise_usage, columnwise=columnwise_usage)
return out, all_gather_outputs
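The contract between fsdp_pre_all_gather and fsdp_post_all_gather can be summarized by how the (fp8_dtype, rowwise_usage, columnwise_usage) metadata drives the slicing of the allgather outputs; a hedged sketch with a hypothetical helper name:

def split_allgather_outputs(outputs, rowwise_usage, columnwise_usage):
    # The first two entries carry rowwise data/scale_inv, the last two carry the
    # columnwise ones, matching the order emitted by fsdp_pre_all_gather.
    rowwise = outputs[:2] if rowwise_usage else (None, None)
    columnwise = outputs[-2:] if columnwise_usage else (None, None)
    return rowwise, columnwise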
@classmethod
......@@ -687,6 +722,7 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
dtype: torch.dtype,
shape: torch.Size,
quantizer: Optional[Quantizer] = None,
with_gemm_swizzled_scales: bool = False,
) -> MXFP8Tensor:
"""Build MXFP8Tensor, for use in __reduce__
......@@ -703,6 +739,7 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
dtype=dtype,
shape=shape,
quantizer=quantizer,
with_gemm_swizzled_scales=with_gemm_swizzled_scales,
)
def __reduce_ex__(self, protocol: int) -> tuple:
......@@ -718,6 +755,7 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
self.dtype,
self.shape,
self._quantizer,
self._with_gemm_swizzled_scales,
),
)
......@@ -739,7 +777,7 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
if not devices_match(new_device, tensor.device):
tensor = tensor.to(device=new_device)
# Just copy FP8 data if other tensor is MXFP8Tensor
# Just copy data if other tensor is MXFP8Tensor
if isinstance(tensor, MXFP8Tensor):
if ( # pylint: disable=too-many-boolean-expressions
self.size() != tensor.size()
......@@ -767,6 +805,7 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
self._fp8_dtype = tensor._fp8_dtype
self._rowwise_scale_inv = tensor._rowwise_scale_inv
self._columnwise_scale_inv = tensor._columnwise_scale_inv
self._with_gemm_swizzled_scales = tensor._with_gemm_swizzled_scales
return
# Quantize to FP8
......@@ -838,6 +877,7 @@ class _ViewFunc(torch.autograd.Function):
columnwise_scale_inv=tensor._columnwise_scale_inv,
fp8_dtype=tensor._fp8_dtype,
quantizer=tensor._quantizer,
with_gemm_swizzled_scales=tensor._with_gemm_swizzled_scales,
)
@staticmethod
......@@ -864,6 +904,7 @@ class _ViewFunc(torch.autograd.Function):
columnwise_scale_inv=grad._columnwise_scale_inv,
fp8_dtype=grad._fp8_dtype,
quantizer=grad._quantizer,
with_gemm_swizzled_scales=grad._with_gemm_swizzled_scales,
)
return dgrad, None
return grad.view(ctx.shape), None
......@@ -924,6 +965,7 @@ class _ReshapeFunc(torch.autograd.Function):
columnwise_scale_inv=tensor._columnwise_scale_inv,
fp8_dtype=tensor._fp8_dtype,
quantizer=tensor._quantizer,
with_gemm_swizzled_scales=tensor._with_gemm_swizzled_scales,
)
@staticmethod
......@@ -949,6 +991,7 @@ class _ReshapeFunc(torch.autograd.Function):
columnwise_scale_inv=grad._columnwise_scale_inv,
fp8_dtype=grad._fp8_dtype,
quantizer=grad._quantizer,
with_gemm_swizzled_scales=grad._with_gemm_swizzled_scales,
)
return dgrad, None
return grad.view(ctx.shape), None
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......@@ -28,9 +28,9 @@ from ._quantization_helpers import _IdentityFunc
aten = torch.ops.aten
def get_no_random_sign_vector() -> torch.Tensor:
def get_no_random_sign_vector(device: int) -> torch.Tensor:
"""Non-random sign vector for Hadamard transform."""
return torch.tensor([1], dtype=torch.float32, device="cuda")
return torch.tensor([1], dtype=torch.float32, device=device)
def get_sign_from_vector(vector: torch.Tensor) -> int:
......@@ -45,7 +45,7 @@ def get_sign_from_vector(vector: torch.Tensor) -> int:
return mask.item()
def get_wgrad_sign_vector() -> torch.Tensor:
def get_wgrad_sign_vector(device: int) -> torch.Tensor:
"""Hard-coded random signs for Hadamard transform.
https://xkcd.com/221/
......@@ -54,11 +54,11 @@ def get_wgrad_sign_vector() -> torch.Tensor:
return torch.tensor(
[1, 1, 1, -1, 1, -1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1],
dtype=torch.float32,
device="cuda",
device=device,
)
def get_hadamard_matrix(hadamard_dimension: int) -> torch.Tensor:
def get_hadamard_matrix(hadamard_dimension: int, device: int) -> torch.Tensor:
"""Construct a 16x16 Hadamard matrix."""
assert hadamard_dimension == 16, "Only hadamard dimension 16 is supported."
hadamard_scale = 1 / math.sqrt(hadamard_dimension)
......@@ -83,30 +83,30 @@ def get_hadamard_matrix(hadamard_dimension: int) -> torch.Tensor:
[1, -1, -1, 1, -1, 1, 1, -1, -1, 1, 1, -1, 1, -1, -1, 1],
],
dtype=torch.float32,
device="cuda",
device=device,
)
* hadamard_scale
)
@functools.lru_cache(maxsize=None)
def get_rht_matrix(with_random_sign_mask: bool) -> torch.Tensor:
def get_rht_matrix(with_random_sign_mask: bool, device: int) -> torch.Tensor:
"""Construct matrix used in random Hadamard transform."""
hadamard_dimension = 16
if with_random_sign_mask:
signs = get_wgrad_sign_vector()
signs = get_wgrad_sign_vector(device=device)
else:
signs = get_no_random_sign_vector()
sign_matrix = signs * torch.eye(hadamard_dimension, dtype=torch.float32, device="cuda")
rht_matrix = sign_matrix @ get_hadamard_matrix(hadamard_dimension)
signs = get_no_random_sign_vector(device=device)
sign_matrix = signs * torch.eye(hadamard_dimension, dtype=torch.float32, device=device)
rht_matrix = sign_matrix @ get_hadamard_matrix(hadamard_dimension, device=device)
return rht_matrix.to(dtype=torch.bfloat16)
@functools.lru_cache(maxsize=None)
def get_random_sign_mask_for_rht(with_random_sign_mask: bool) -> int:
def get_random_sign_mask_for_rht(with_random_sign_mask: bool, device: int) -> int:
"""Sign mask for random Hadamard transform."""
if with_random_sign_mask:
return get_sign_from_vector(get_wgrad_sign_vector())
return get_sign_from_vector(get_wgrad_sign_vector(device=device))
return 0
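Since the Hadamard matrix is scaled by 1/sqrt(16) and the sign matrix is a diagonal of ±1, the cached RHT matrix is orthogonal; a quick hedged sanity check (needs a CUDA device, loose tolerance for bfloat16):

import torch
rht = get_rht_matrix(True, torch.cuda.current_device())
identity = torch.eye(16, dtype=torch.bfloat16, device=rht.device)
assert torch.allclose(rht @ rht.T, identity, atol=1e-2)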
......@@ -152,8 +152,10 @@ class NVFP4Quantizer(Quantizer):
self.amax_reduction_group = amax_reduction_group
self.with_2d_quantization = with_2d_quantization
self.stochastic_rounding = stochastic_rounding
self.rht_matrix_random_sign_mask_t = get_random_sign_mask_for_rht(with_random_sign_mask)
self.rht_matrix = get_rht_matrix(with_random_sign_mask)
self.rht_matrix_random_sign_mask_t = get_random_sign_mask_for_rht(
with_random_sign_mask, torch.cuda.current_device()
)
self.rht_matrix = get_rht_matrix(with_random_sign_mask, torch.cuda.current_device())
def update_quantized(
self,
......@@ -176,6 +178,27 @@ class NVFP4Quantizer(Quantizer):
return dst
def copy(self) -> NVFP4Quantizer:
"""Create shallow copy"""
quantizer = NVFP4Quantizer(
fp4_dtype=self.dtype,
rowwise=self.rowwise_usage,
columnwise=self.columnwise_usage,
with_amax_reduction=self.with_amax_reduction,
amax_reduction_group=self.amax_reduction_group,
with_rht=self.with_rht,
with_post_rht_amax=self.with_post_rht_amax,
with_2d_quantization=self.with_2d_quantization,
stochastic_rounding=self.stochastic_rounding,
)
quantizer.internal = self.internal
quantizer.optimize_for_gemm = self.optimize_for_gemm
quantizer.rht_matrix = self.rht_matrix
quantizer.rht_matrix_random_sign_mask_t = self.rht_matrix_random_sign_mask_t
return quantizer
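A consequence of the shallow copy above is that the clone shares the cached RHT tensor while its usage flags can diverge from the original; a brief hedged usage sketch (assuming `quantizer` is an existing NVFP4Quantizer and set_usage behaves as on the other quantizers):

clone = quantizer.copy()
assert clone.rht_matrix is quantizer.rht_matrix   # cached tensor is shared
clone.set_usage(rowwise=True, columnwise=False)   # does not affect the original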
def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor:
"""Quantize tensor implementation"""
return tex.quantize(tensor, self)
......@@ -337,6 +360,7 @@ class NVFP4Quantizer(Quantizer):
fp4_dtype=self.dtype,
quantizer=self,
requires_grad=requires_grad,
with_gemm_swizzled_scales=False,
)
def calibrate(self, tensor: torch.Tensor) -> None:
......@@ -360,26 +384,26 @@ class NVFP4Tensor(NVFP4TensorStorage, QuantizedTensor):
Parameters
----------
rowwise_data: torch.Tensor
rowwise_data : torch.Tensor
Raw FP4 data in a uint8 tensor (rowwise layout).
rowwise_scale_inv: torch.Tensor
rowwise_scale_inv : torch.Tensor
Reciprocal of the scaling factor applied when
casting to FP4, i.e. the scaling factor that must
be applied when casting from FP4 to higher
precision (rowwise).
columnwise_data: torch.Tensor, optional
columnwise_data : torch.Tensor, optional
Raw FP4 data in a uint8 tensor (columnwise layout).
columnwise_scale_inv: torch.Tensor, optional
columnwise_scale_inv : torch.Tensor, optional
Reciprocal of the scaling factor for columnwise FP4 data.
amax_rowwise: torch.Tensor, optional
amax_rowwise : torch.Tensor, optional
Rowwise amax tracking tensor.
amax_columnwise: torch.Tensor, optional
amax_columnwise : torch.Tensor, optional
Columnwise amax tracking tensor.
fp4_dtype: TE_DType
fp4_dtype : TE_DType
The FP4 data type used for quantization.
quantizer: Quantizer
quantizer : Quantizer
The quantizer instance used for this tensor.
dtype: torch.dtype, default = torch.float32
dtype : torch.dtype, default = torch.float32
Nominal tensor datatype, used in dequantize.
"""
......@@ -396,6 +420,7 @@ class NVFP4Tensor(NVFP4TensorStorage, QuantizedTensor):
amax_columnwise: Optional[torch.Tensor],
fp4_dtype: TE_DType,
quantizer: Quantizer,
with_gemm_swizzled_scales: bool,
**kwargs,
):
instance = super().__new__(
......@@ -408,6 +433,7 @@ class NVFP4Tensor(NVFP4TensorStorage, QuantizedTensor):
amax_columnwise,
fp4_dtype,
quantizer,
with_gemm_swizzled_scales,
*args,
**kwargs,
)
......@@ -570,6 +596,7 @@ class NVFP4Tensor(NVFP4TensorStorage, QuantizedTensor):
amax_columnwise=amax_columnwise,
quantizer=tensor._quantizer,
requires_grad=tensor.requires_grad,
with_gemm_swizzled_scales=tensor._with_gemm_swizzled_scales,
)
# Default case
......@@ -588,6 +615,7 @@ class NVFP4Tensor(NVFP4TensorStorage, QuantizedTensor):
fp4_dtype: TE_DType,
dtype: torch.dtype,
quantizer: Quantizer,
with_gemm_swizzled_scales: bool = False,
) -> NVFP4Tensor:
"""Build NVFP4Tensor, for use in __reduce__
......@@ -607,6 +635,7 @@ class NVFP4Tensor(NVFP4TensorStorage, QuantizedTensor):
amax_columnwise=amax_columnwise,
quantizer=quantizer,
requires_grad=False,
with_gemm_swizzled_scales=with_gemm_swizzled_scales,
)
def __reduce_ex__(self, protocol: int) -> tuple:
......@@ -624,6 +653,7 @@ class NVFP4Tensor(NVFP4TensorStorage, QuantizedTensor):
self._fp4_dtype,
self.dtype,
self._quantizer,
self._with_gemm_swizzled_scales,
),
)
......@@ -674,6 +704,7 @@ class NVFP4Tensor(NVFP4TensorStorage, QuantizedTensor):
self._columnwise_scale_inv = tensor._columnwise_scale_inv
self._amax_rowwise = tensor._amax_rowwise
self._amax_columnwise = tensor._amax_columnwise
self._with_gemm_swizzled_scales = tensor._with_gemm_swizzled_scales
return
# Quantize to FP8
......@@ -760,6 +791,7 @@ class _ViewFunc(torch.autograd.Function):
quantizer=tensor._quantizer,
fp4_dtype=tensor._fp4_dtype,
requires_grad=tensor.requires_grad,
with_gemm_swizzled_scales=tensor._with_gemm_swizzled_scales,
)
@staticmethod
......@@ -801,6 +833,7 @@ class _ViewFunc(torch.autograd.Function):
quantizer=grad._quantizer,
fp4_dtype=grad._fp4_dtype,
requires_grad=grad.requires_grad,
with_gemm_swizzled_scales=grad._with_gemm_swizzled_scales,
)
return dgrad, None
return grad.view(ctx.shape), None
......@@ -880,6 +913,7 @@ class _ReshapeFunc(torch.autograd.Function):
quantizer=tensor._quantizer,
fp4_dtype=tensor._fp4_dtype,
requires_grad=tensor.requires_grad,
with_gemm_swizzled_scales=tensor._with_gemm_swizzled_scales,
)
@staticmethod
......@@ -921,6 +955,7 @@ class _ReshapeFunc(torch.autograd.Function):
quantizer=grad._quantizer,
fp4_dtype=grad._fp4_dtype,
requires_grad=grad.requires_grad,
with_gemm_swizzled_scales=grad._with_gemm_swizzled_scales,
)
return dgrad, None
return grad.view(ctx.shape), None
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Storage for quantized tensors."""
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......@@ -11,7 +11,6 @@ import torch
import transformer_engine_torch as tex
from transformer_engine_torch import DType as TE_DType
from transformer_engine_torch import Float8BlockScaleTensorFormat
from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
from ...quantized_tensor import QuantizedTensorStorage, Quantizer
......@@ -37,7 +36,6 @@ class Float8BlockwiseQTensorStorage(QuantizedTensorStorage):
_rowwise_scale_inv: Optional[torch.Tensor]
_columnwise_scale_inv: Optional[torch.Tensor]
_is_2D_scaled: bool
_data_format: Float8BlockScaleTensorFormat
def __new__(
cls,
......@@ -48,7 +46,6 @@ class Float8BlockwiseQTensorStorage(QuantizedTensorStorage):
fp8_dtype: TE_DType,
quantizer: Quantizer,
is_2D_scaled: bool,
data_format: Float8BlockScaleTensorFormat,
*args,
**kwargs,
):
......@@ -63,7 +60,6 @@ class Float8BlockwiseQTensorStorage(QuantizedTensorStorage):
instance._rowwise_scale_inv = rowwise_scale_inv
instance._columnwise_scale_inv = columnwise_scale_inv
instance._is_2D_scaled = is_2D_scaled
instance._data_format = data_format
return instance
......@@ -88,13 +84,8 @@ class Float8BlockwiseQTensorStorage(QuantizedTensorStorage):
"fp8_dtype": self._fp8_dtype,
"quantizer": self._quantizer,
"is_2D_scaled": self._is_2D_scaled,
"data_format": self._data_format,
}
def _is_gemm_ready_format(self) -> bool:
"""Whether data is in GEMM_READY format"""
return self._data_format == Float8BlockScaleTensorFormat.GEMM_READY
def prepare_for_saving(
self,
) -> Tuple[list[Optional[torch.Tensor]], Float8BlockwiseQTensorStorage]:
......@@ -154,36 +145,18 @@ class Float8BlockwiseQTensorStorage(QuantizedTensorStorage):
for i in range(len(q.shape) - 1):
q_M *= q.shape[i]
inner_q_dimension_tiled = True
if self._is_gemm_ready_format():
scales_tiled_dim, scales_untiled_dim = scale_inv.shape
inner_scale_dimension_tiled = False
scales_are_compact = False
else:
scales_untiled_dim, scales_tiled_dim = scale_inv.shape
inner_scale_dimension_tiled = True
scales_are_compact = True
scales_tiled_dim, scales_untiled_dim = scale_inv.shape
else:
assert self._columnwise_data is not None, "No data to dequantize"
q = self._columnwise_data
scale_inv = self._columnwise_scale_inv
scales_tiled_dim, scales_untiled_dim = scale_inv.shape
inner_scale_dimension_tiled = False
if self._is_gemm_ready_format():
inner_q_dimension_tiled = True
transpose_output = True
if len(q.shape) >= 1:
q_M = q.shape[0]
for i in range(1, len(q.shape)):
q_K *= q.shape[i]
scales_are_compact = False
else:
inner_q_dimension_tiled = False
transpose_output = False
if len(q.shape) >= 1:
q_K = q.shape[-1]
for i in range(len(q.shape) - 1):
q_M *= q.shape[i]
scales_are_compact = True
inner_q_dimension_tiled = True
transpose_output = True
if len(q.shape) >= 1:
q_M = q.shape[0]
for i in range(1, len(q.shape)):
q_K *= q.shape[i]
orig_shape = q.shape
q = q.reshape(q_M, q_K)
......@@ -203,15 +176,10 @@ class Float8BlockwiseQTensorStorage(QuantizedTensorStorage):
).contiguous()
padded_M, padded_K = q.shape
q_tiled = q.reshape(scales_tiled_dim, block_len, q_K)
if not scales_are_compact and scales_untiled_dim > q_M:
if scales_untiled_dim > q_M:
# untiled scale dimension is 4 element aligned.
scale_inv = scale_inv[:, :q_M].contiguous()
if scales_are_compact and inner_scale_dimension_tiled:
dq_scale = scale_inv.contiguous().reshape(q_M, scales_tiled_dim, 1)
elif scales_are_compact and not inner_scale_dimension_tiled:
dq_scale = scale_inv.contiguous().reshape(scales_tiled_dim, 1, q_K)
else:
dq_scale = scale_inv.transpose(-2, -1).contiguous().reshape(q_M, scales_tiled_dim, 1)
dq_scale = scale_inv.transpose(-2, -1).contiguous().reshape(q_M, scales_tiled_dim, 1)
torch_q_dtype = TE_DType_To_Torch[self._fp8_dtype]
result = q_tiled.view(torch_q_dtype).to(torch.float32) * dq_scale
if padded_M != q_M or padded_K != q_K:
......@@ -234,12 +202,6 @@ class Float8BlockwiseQTensorStorage(QuantizedTensorStorage):
if not self._is_2D_scaled:
return self._dequantize_vectorwise(dtype=dtype)
if not self._is_gemm_ready_format():
raise NotImplementedError(
"Dequantize is only supported with GEMM_READY data format, "
f"but found _data_format={self._data_format}"
)
def format_scale_as_logical_shape(q_K, scales, block_len):
# The GEMM for 2D blocks required padding in the scales.
derived_scale_k_shape = math.ceil(q_K / block_len)
......@@ -305,8 +267,6 @@ class Float8BlockwiseQTensorStorage(QuantizedTensorStorage):
if self._rowwise_data is not None:
return self._rowwise_data.size(*args, **kwargs)
dims = list(self._columnwise_data.size(*args, **kwargs))
if not self._is_gemm_ready_format(): # compact format
return torch.Size(dims)
reordered = []
for i in range(1, len(dims)):
reordered.append(dims[i])
......@@ -367,7 +327,7 @@ class Float8BlockwiseQTensorStorage(QuantizedTensorStorage):
return (
"Float8BlockwiseQTensorStorage("
f"fp8_dtype={self._fp8_dtype}, "
f"{descriptor}_scaled_data={data}"
f"{descriptor}_scaled_data={data})"
)
def update_usage(
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......@@ -57,13 +57,23 @@ class MXFP8TensorStorage(QuantizedTensorStorage):
"""
# Row-scaled FP8 data
_rowwise_data: Optional[torch.Tensor]
# Column-scaled FP8 data
_columnwise_data: Optional[torch.Tensor]
_quantizer: Optional[Quantizer]
_fp8_dtype: TE_DType
# Scaling factors for row-scaled FP8 data
_rowwise_scale_inv: torch.Tensor
# Scaling factors for column-scaled FP8 data
_columnwise_scale_inv: torch.Tensor
# Builder class for casting to MXFP8
_quantizer: Optional[Quantizer]
# FP8 data type
_fp8_dtype: TE_DType
# Whether scaling factors are in the swizzled format expected by
# GEMM
_with_gemm_swizzled_scales: bool
def __new__(
cls,
rowwise_data: Optional[torch.Tensor],
......@@ -72,6 +82,7 @@ class MXFP8TensorStorage(QuantizedTensorStorage):
columnwise_scale_inv: Optional[torch.Tensor],
fp8_dtype: TE_DType,
quantizer: Optional[Quantizer],
with_gemm_swizzled_scales: bool,
*args,
**kwargs,
):
......@@ -81,10 +92,11 @@ class MXFP8TensorStorage(QuantizedTensorStorage):
instance = super().__new__(cls, *args, **kwargs)
instance._rowwise_data = rowwise_data
instance._columnwise_data = columnwise_data
instance._quantizer = quantizer.copy() if quantizer is not None else None
instance._fp8_dtype = fp8_dtype
instance._rowwise_scale_inv = rowwise_scale_inv
instance._columnwise_scale_inv = columnwise_scale_inv
instance._quantizer = quantizer.copy() if quantizer is not None else None
instance._fp8_dtype = fp8_dtype
instance._with_gemm_swizzled_scales = with_gemm_swizzled_scales
return instance
......@@ -108,6 +120,7 @@ class MXFP8TensorStorage(QuantizedTensorStorage):
"columnwise_scale_inv": self._columnwise_scale_inv,
"fp8_dtype": self._fp8_dtype,
"quantizer": self._quantizer,
"with_gemm_swizzled_scales": self._with_gemm_swizzled_scales,
}
def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], MXFP8TensorStorage]:
......@@ -197,6 +210,7 @@ class MXFP8TensorStorage(QuantizedTensorStorage):
columnwise_scale_inv=self._columnwise_scale_inv,
fp8_dtype=self._fp8_dtype,
quantizer=self._quantizer,
with_gemm_swizzled_scales=self._with_gemm_swizzled_scales,
)
def __repr__(self):
......@@ -255,7 +269,7 @@ class MXFP8TensorStorage(QuantizedTensorStorage):
self._columnwise_data = None
self._columnwise_scale_inv = None
def get_usages(self) -> Tuple[bool, bool]:
def get_usages(self) -> Dict[str, bool]:
"""Get the usage of the tensor"""
return {
"rowwise": self._rowwise_data is not None,
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......@@ -71,15 +71,29 @@ class NVFP4TensorStorage(QuantizedTensorStorage):
"""
# Row-scaled FP4 data
_rowwise_data: Optional[torch.Tensor]
# Column-scaled FP4 data
_columnwise_data: Optional[torch.Tensor]
_quantizer: Optional[Quantizer]
# Block scaling factors for row-scaled FP4 data
_rowwise_scale_inv: torch.Tensor
# Block scaling factors for column-scaled FP4 data
_columnwise_scale_inv: torch.Tensor
_fp4_dtype: TE_DType
# Input absolute maximum value (used to compute tensor scale for
# row-scaled FP4 data)
_amax_rowwise: torch.Tensor
# Input absolute maximum value (used to compute tensor scale for
# column-scaled FP4 data)
_amax_columnwise: torch.Tensor
# Builder class for casting to NVFP4
_quantizer: Optional[Quantizer]
# FP4 data type
_fp4_dtype: TE_DType
# Whether scaling factors are in the swizzled format expected by
# GEMM
_with_gemm_swizzled_scales: bool
def __new__(
cls,
rowwise_data: Optional[torch.Tensor],
......@@ -90,6 +104,7 @@ class NVFP4TensorStorage(QuantizedTensorStorage):
amax_columnwise: torch.Tensor,
fp4_dtype: TE_DType,
quantizer: Optional[Quantizer],
with_gemm_swizzled_scales: bool,
*args,
**kwargs,
):
......@@ -104,6 +119,7 @@ class NVFP4TensorStorage(QuantizedTensorStorage):
instance._columnwise_scale_inv = columnwise_scale_inv
instance._amax_rowwise = amax_rowwise
instance._amax_columnwise = amax_columnwise
instance._with_gemm_swizzled_scales = with_gemm_swizzled_scales
return instance
......@@ -131,6 +147,7 @@ class NVFP4TensorStorage(QuantizedTensorStorage):
"amax_columnwise": self._amax_columnwise,
"fp4_dtype": self._fp4_dtype,
"quantizer": self._quantizer,
"with_gemm_swizzled_scales": self._with_gemm_swizzled_scales,
}
def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], NVFP4TensorStorage]:
......@@ -248,6 +265,7 @@ class NVFP4TensorStorage(QuantizedTensorStorage):
amax_columnwise=self._amax_columnwise,
quantizer=self._quantizer,
fp4_dtype=self._fp4_dtype,
with_gemm_swizzled_scales=self._with_gemm_swizzled_scales,
)
def __repr__(self):
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......@@ -8,7 +8,11 @@ from typing import Optional, Union, List
import torch
from torch.utils.cpp_extension import IS_HIP_EXTENSION
import transformer_engine_torch as tex
from transformer_engine_torch import multi_tensor_scale, multi_tensor_compute_scale_and_scale_inv
from transformer_engine_torch import (
multi_tensor_scale,
multi_tensor_compute_scale_and_scale_inv,
multi_tensor_compute_scale_inv_e8m0,
)
from ..quantized_tensor import QuantizedTensor, Quantizer, QuantizedTensorStorage
from .float8_tensor import Float8Tensor, Float8Quantizer, Float8CurrentScalingQuantizer
......@@ -74,7 +78,7 @@ def cast_master_weights_to_fp8(
fsdp_shard_model_weights : list of FSDP shard model weights. If None, it means that the model weights are
not sharded. Otherwise, it means that the model weights are sharded and we get
target model weights data storage using the FSDP shard model weights.
manual_post_all_gather_processing: bool, default = `False`.
manual_post_all_gather_processing : bool, default = `False`.
If False, post processing will be automatically triggered during next forward.
If True, the timing of calling post_all_gather_processing is left to the user.
Note that users must call `post_all_gather_processing` if it's set to True,
......@@ -85,6 +89,7 @@ def cast_master_weights_to_fp8(
delayed_scaling_params = []
current_scaling_params = []
blockwise_scaling_params = []
mxfp8_scaling_params = []
if fsdp_shard_model_weights is None:
use_fsdp_shard_model_weights = False
......@@ -131,8 +136,8 @@ def cast_master_weights_to_fp8(
(model_weight, master_weight, start_offset, fsdp_shard_model_weight)
)
elif isinstance(quantizer, MXFP8Quantizer):
raise NotImplementedError(
"cast_master_weights_to_fp8 for MXFP8BlockScaling is not supported yet"
mxfp8_scaling_params.append(
(model_weight, master_weight, start_offset, fsdp_shard_model_weight)
)
else:
raise ValueError(
......@@ -146,6 +151,8 @@ def cast_master_weights_to_fp8(
_cast_master_weights_to_fp8_current_scaling(current_scaling_params, *extra_args)
if len(blockwise_scaling_params) > 0:
_cast_master_weights_to_fp8_blockwise_scaling(blockwise_scaling_params, *extra_args)
if len(mxfp8_scaling_params) > 0:
_cast_master_weights_to_fp8_mxfp8_scaling(mxfp8_scaling_params, *extra_args)
def _cast_master_weights_to_fp8_delayed_scaling(
......@@ -471,6 +478,131 @@ def _cast_master_weights_to_fp8_blockwise_scaling(
)
def _cast_master_weights_to_fp8_mxfp8_scaling(
params, group, use_fsdp_shard_model_weights=False, manual_post_all_gather_processing=False
): # pylint: disable=unused-argument
r"""Helper function to cast master weights to FP8 primary weights for mxfp8 scaling.
Parameters
----------
params : List of tuples; each tuple contains a model weight, a master weight, an offset
indicating the starting index of the master weight in the model weight, and the
corresponding FSDP-sharded model weight (unused when weights are not FSDP-sharded).
group : The distributed group to do amax reduction. Typically it's the data parallel
group.
use_fsdp_shard_model_weights : bool, if True, it means that the model weights are sharded.
"""
# Parameter attributes
device = params[0][0].device
for _, master_weight, _, _ in params:
if master_weight is not None:
master_weight_dtype = master_weight.dtype
break
# Get the total number of amax elements in all the model weights.
cu_rowwise_amax_sizes = [0]
cu_colwise_amax_sizes = [0]
for model_weight, _, _, _ in params:
rowwise_shape = model_weight._rowwise_scale_inv.shape
assert len(rowwise_shape) == 2
colwise_shape = model_weight._columnwise_scale_inv.shape
assert len(colwise_shape) == 2
cu_rowwise_amax_sizes.append(
cu_rowwise_amax_sizes[-1] + rowwise_shape[0] * rowwise_shape[1]
)
cu_colwise_amax_sizes.append(
cu_colwise_amax_sizes[-1] + colwise_shape[0] * colwise_shape[1]
)
# Create a contiguous buffer to store amaxes temporarily, so all of them can be
# reduced with a single NCCL all-reduce.
packed_amaxes = torch.zeros(
cu_rowwise_amax_sizes[-1] + cu_colwise_amax_sizes[-1],
dtype=master_weight_dtype,
device=device,
)
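The cumulative-size bookkeeping is an exclusive prefix sum over the per-tensor scale shapes, so tensor i's amaxes occupy packed_amaxes[cu[i]:cu[i+1]]; a toy illustration with hypothetical shapes:

shapes = [(4, 8), (2, 8)]            # hypothetical scale_inv shapes
cu = [0]
for h, w in shapes:
    cu.append(cu[-1] + h * w)
# cu == [0, 32, 48]; tensor 1 occupies the flat slice [32:48]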
# ---------------------------------------------------------------------------------------------
# Step 1: Iterate through all the non-empty master weights and compute their amaxes. Store the
# amaxes in a contiguous buffer. If a block of a master weight is empty, the
# corresponding amax will be set to 0.
# ---------------------------------------------------------------------------------------------
amaxes_rowwise, scale_invs_rowwise = [], []
amaxes_colwise, scale_invs_colwise = [], []
for i, (model_weight, master_weight, start_offset, _) in enumerate(params):
rowwise_shape = model_weight._rowwise_scale_inv.shape
colwise_shape = model_weight._columnwise_scale_inv.shape
rowwise_start = cu_rowwise_amax_sizes[i]
rowwise_end = cu_rowwise_amax_sizes[i + 1]
colwise_start = cu_rowwise_amax_sizes[-1] + cu_colwise_amax_sizes[i]
colwise_end = cu_rowwise_amax_sizes[-1] + cu_colwise_amax_sizes[i + 1]
amax_rowwise = packed_amaxes[rowwise_start:rowwise_end].reshape(rowwise_shape)
amax_colwise = packed_amaxes[colwise_start:colwise_end].reshape(colwise_shape)
amaxes_rowwise.append(amax_rowwise)
amaxes_colwise.append(amax_colwise)
scale_invs_rowwise.append(model_weight._rowwise_scale_inv)
scale_invs_colwise.append(model_weight._columnwise_scale_inv)
# Compute amax of the master weight and store it in packed_amaxes.
if master_weight is not None:
assert len(model_weight.shape) == 2
h, w = model_weight.shape
tex.mxfp8_scaling_compute_partial_amax(
master_weight, amax_rowwise, amax_colwise, h, w, start_offset
)
# ---------------------------------------------------------------------------------------------
# Step 2: Perform all-reduce on packed_amaxes to get the global amax.
# ---------------------------------------------------------------------------------------------
torch.distributed.all_reduce(packed_amaxes, op=torch.distributed.ReduceOp.MAX, group=group)
# ---------------------------------------------------------------------------------------------
# Step 3: Update scales and scale_invs.
# ---------------------------------------------------------------------------------------------
multi_tensor_applier(
multi_tensor_compute_scale_inv_e8m0,
None, # dummy_overflow_buf
[
amaxes_rowwise + amaxes_colwise,
scale_invs_rowwise + scale_invs_colwise,
],
)
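Conceptually, the fused kernel maps each reduced amax to a power-of-two decode scale encoded as an E8M0 byte; a rough, hedged pure-PyTorch sketch of the idea (the constant, rounding, and clamping here are assumptions, not the extension's exact behavior):

import torch

FP8_E4M3_MAX = 448.0  # assumed target-format maximum

def compute_scale_inv_e8m0(amax: torch.Tensor) -> torch.Tensor:
    # Smallest power of two 2^e such that amax / 2^e fits into the FP8 range.
    exp = torch.ceil(torch.log2(amax.clamp(min=2.0**-127) / FP8_E4M3_MAX))
    exp = exp.clamp(-127.0, 127.0)
    # E8M0 stores the biased exponent (bias 127) in a uint8.
    return (exp + 127.0).to(torch.uint8)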
# ---------------------------------------------------------------------------------------------
# Step 4: Cast master weights to FP8.
# ---------------------------------------------------------------------------------------------
for (
(model_weight, master_weight, start_offset, model_weight_fragment),
scale_inv_rowwise,
scale_inv_colwise,
) in zip(params, scale_invs_rowwise, scale_invs_colwise):
# If master weight is None, it means that the master weight of the current model weight
# is owned by other DP ranks.
if master_weight is None:
continue
# Cast master weight to FP8
end_offset = start_offset + master_weight.numel()
if use_fsdp_shard_model_weights:
rowwise_fragment = model_weight_fragment[0]
colwise_fragment = model_weight_fragment[1]
else:
rowwise_fragment = model_weight._rowwise_data.reshape(-1)[start_offset:end_offset]
colwise_fragment = model_weight._columnwise_data.reshape(-1)[start_offset:end_offset]
assert len(model_weight.shape) == 2
h, w = model_weight.shape
tex.mxfp8_scaling_partial_cast(
master_weight,
rowwise_fragment,
colwise_fragment,
scale_inv_rowwise,
scale_inv_colwise,
h,
w,
start_offset,
)
def post_all_gather_processing(model_weights: Union[torch.Tensor, List[torch.Tensor]]):
"""
Post-processing after all-gather for weights in distributed optimizer.
......@@ -489,6 +621,9 @@ def post_all_gather_processing(model_weights: Union[torch.Tensor, List[torch.Ten
elif isinstance(model_weight, Float8BlockwiseQTensor):
# Blockwise scaling: create column-wise storage.
model_weight._create_columnwise()
elif isinstance(model_weight, MXFP8Tensor):
# MXFP8 scaling: no need to do anything.
pass
elif isinstance(model_weight, QuantizedTensor):
raise ValueError(f"post_processing for {type(model_weight)} is not supported")
......
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""PyTorch version utilities"""
from __future__ import annotations
import functools
import torch
from packaging.version import Version as PkgVersion
@functools.lru_cache(maxsize=None)
def torch_version() -> tuple[int, ...]:
"""Get PyTorch version"""
return PkgVersion(str(torch.__version__)).release
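Callers can then gate version-dependent behavior on the parsed release tuple, for example (a hypothetical usage sketch):

if torch_version() >= (2, 4):
    # Enable a code path that relies on a newer PyTorch release.
    pass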