Commit 063ef88d authored by wenjh

Merge nv main up to v2.10.0.dev0


Signed-off-by: wenjh <wenjh@sugon.com>
parents 91670b05 5624dbb4
...@@ -22,7 +22,7 @@ from .base import ( ...@@ -22,7 +22,7 @@ from .base import (
_2X_ACC_WGRAD, _2X_ACC_WGRAD,
) )
from ._common import WeightGradStore from ._common import WeightGradStore
from ..fp8 import FP8GlobalStateManager from ..quantization import FP8GlobalStateManager
from ..utils import ( from ..utils import (
divide, divide,
cast_if_needed, cast_if_needed,
...@@ -46,7 +46,7 @@ from ..cpu_offload import is_cpu_offload_enabled ...@@ -46,7 +46,7 @@ from ..cpu_offload import is_cpu_offload_enabled
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
from ..tensor.quantized_tensor import ( from ..tensor.quantized_tensor import (
QuantizedTensorBase, QuantizedTensorStorage,
Quantizer, Quantizer,
prepare_for_saving, prepare_for_saving,
restore_from_saved, restore_from_saved,
...@@ -204,39 +204,27 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -204,39 +204,27 @@ class _GroupedLinear(torch.autograd.Function):
inputmats[0] = inp inputmats[0] = inp
else: else:
for inputmat in inputmats: for inputmat in inputmats:
if isinstance(inputmat, QuantizedTensorBase): if isinstance(inputmat, QuantizedTensorStorage):
inputmat.update_usage(rowwise_usage=False, columnwise_usage=True) inputmat.update_usage(rowwise_usage=False, columnwise_usage=True)
else: else:
inputmats = [None] * num_gemms inputmats = [None] * num_gemms
if inp.requires_grad: if inp.requires_grad:
for weight in weights_fp8: for weight in weights_fp8:
if isinstance(weight, QuantizedTensorBase): if isinstance(weight, QuantizedTensorStorage):
weight.update_usage(columnwise_usage=True) weight.update_usage(columnwise_usage=True)
-            for i in range(num_gemms):
-                weights[i].offloading_activation = False
-                weights_fp8[i].offloading_activation = False
-                biases[i].offloading_activation = False
-            ctx.fine_grained_activation_offloading = fine_grained_activation_offloading
-            if fine_grained_activation_offloading and cpu_offloading:
-                raise ValueError(
-                    f"Do not use fine_grained_activation_offloading and cpu_offloading at the same time."
-                )
-            if (
-                fine_grained_activation_offloading
-                and weights[0].requires_grad
-                and fuse_wgrad_accumulation
-            ):
-                grad_added_to_main_grad_list = []
-                for weight in weights:
-                    if weight.requires_grad and hasattr(weight, "grad_added_to_main_grad"):
-                        grad_added_to_main_grad_list.append(weight.grad_added_to_main_grad)
-                        weight.grad_added_to_main_grad = True
-                    else:
-                        grad_added_to_main_grad_list.append(None)
-                ctx.grad_added_to_main_grad_list = grad_added_to_main_grad_list
+            if cpu_offloading:
+                ctx.grad_added_to_main_grad = hasattr(weights[0], "grad_added_to_main_grad")
+                if ctx.grad_added_to_main_grad:
+                    # If you are passing torch.nn.Parameter through the Torch hooks, you will
+                    # get back torch.Tensor. Torch rips off the Parameter wrapper.
+                    # You need to preserve the weight object to have all the attributes user
+                    # sets for the weights. Because of this, it is not recommended to offload
+                    # weights if weights are externally touched outside this module
+                    ctx.weight_objects = []
+                    for weight in weights:
+                        ctx.weight_objects.append(weight)
tensors_to_save, tensor_objects = prepare_for_saving( tensors_to_save, tensor_objects = prepare_for_saving(
*inputmats, *inputmats,
...@@ -300,13 +288,15 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -300,13 +288,15 @@ class _GroupedLinear(torch.autograd.Function):
biases = saved_tensors[3 * N : 4 * N] biases = saved_tensors[3 * N : 4 * N]
main_grads = [main_grad_func() for main_grad_func in ctx.main_grad_funcs] main_grads = [main_grad_func() for main_grad_func in ctx.main_grad_funcs]
-        if (ctx.cpu_offloading or ctx.fine_grained_activation_offloading) and ctx.fuse_wgrad_accumulation:
-            for i in range(ctx.num_gemms):
-                w = torch.nn.Parameter(weights[i], weights[i].requires_grad)
-                w.main_grad = main_grads[i]
-                weights[i] = w
-                if ctx.fine_grained_activation_offloading and weights[0].requires_grad:
-                    weights[i].grad_added_to_main_grad = ctx.grad_added_to_main_grad_list[i]
+        if ctx.cpu_offloading:
+            if ctx.grad_added_to_main_grad:
+                for i, weight in enumerate(ctx.weight_objects):
+                    origin_weights[i] = ctx.weight_objects[i]
+                    ctx.weight_objects[i] = None
+
+        if ctx.fuse_wgrad_accumulation:
+            for i in range(N):
+                origin_weights[i].main_grad = main_grads[i]
# Preprocess grad output # Preprocess grad output
grad_output_view = grad_output.contiguous().view(-1, grad_output.shape[-1]) grad_output_view = grad_output.contiguous().view(-1, grad_output.shape[-1])
...@@ -369,7 +359,7 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -369,7 +359,7 @@ class _GroupedLinear(torch.autograd.Function):
) )
for weight, quantizer in zip(weights, ctx.weight_quantizers): for weight, quantizer in zip(weights, ctx.weight_quantizers):
if quantizer is not None and isinstance(weight, QuantizedTensorBase): if quantizer is not None and isinstance(weight, QuantizedTensorStorage):
weight.update_usage( weight.update_usage(
rowwise_usage=quantizer.rowwise_usage, rowwise_usage=quantizer.rowwise_usage,
columnwise_usage=quantizer.columnwise_usage, columnwise_usage=quantizer.columnwise_usage,
...@@ -433,7 +423,11 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -433,7 +423,11 @@ class _GroupedLinear(torch.autograd.Function):
use_bias=ctx.use_bias if grad_biases[0] is None else None, use_bias=ctx.use_bias if grad_biases[0] is None else None,
bias=biases, bias=biases,
use_split_accumulator=wgrad_gemm_use_split_accumulator, use_split_accumulator=wgrad_gemm_use_split_accumulator,
-                accumulate=accumulate_wgrad_into_param_main_grad,
+                accumulate=(
+                    accumulate_wgrad_into_param_main_grad
+                    if not getattr(weights[0], "overwrite_main_grad", False)
+                    else False
+                ),
) )
# WGRAD # WGRAD
if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute(): if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
...@@ -555,7 +549,9 @@ class GroupedLinear(TransformerEngineBaseModule): ...@@ -555,7 +549,9 @@ class GroupedLinear(TransformerEngineBaseModule):
the weight gradient. When enabled, it is assumed that the weights the weight gradient. When enabled, it is assumed that the weights
have an additional `main_grad` attribute (used instead of the have an additional `main_grad` attribute (used instead of the
regular `grad`) which is a pre-allocated buffer of the correct regular `grad`) which is a pre-allocated buffer of the correct
-    size to accumulate gradients in.
+    size to accumulate gradients in. This argument along with
+    weight tensor having attribute 'overwrite_main_grad' set to True
+    will overwrite `main_grad` instead of accumulating.
return_bias : bool, default = `False` return_bias : bool, default = `False`
when set to `True`, this module will not apply the additive bias itself, but when set to `True`, this module will not apply the additive bias itself, but
instead return the bias value during the forward pass together with the instead return the bias value during the forward pass together with the
...@@ -772,7 +768,7 @@ class GroupedLinear(TransformerEngineBaseModule): ...@@ -772,7 +768,7 @@ class GroupedLinear(TransformerEngineBaseModule):
produced) produced)
""" """
assert not isinstance( assert not isinstance(
inp, QuantizedTensorBase inp, QuantizedTensorStorage
), "GroupedLinear doesn't support input tensor in FP8." ), "GroupedLinear doesn't support input tensor in FP8."
assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs." assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs."
...@@ -907,16 +903,17 @@ class GroupedLinear(TransformerEngineBaseModule): ...@@ -907,16 +903,17 @@ class GroupedLinear(TransformerEngineBaseModule):
self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["bwd"] self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["bwd"]
].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon
def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]: def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage]]:
"""Get the weight tensors of the module.""" """Get the weight tensors of the module."""
weight_tensors = [getattr(self, f"weight{i}") for i in range(self.num_gemms)] weight_tensors = [getattr(self, f"weight{i}") for i in range(self.num_gemms)]
if not self.fp8 and any(isinstance(w, QuantizedTensorBase) for w in weight_tensors): if not self.fp8 and any(isinstance(w, QuantizedTensorStorage) for w in weight_tensors):
warnings.warn( warnings.warn(
"You are using quantized weights without quantized compute. " "You are using quantized weights without quantized compute. "
"Please make sure this is intentional." "Please make sure this is intentional."
) )
         weight_tensors = [
-            w.dequantize() if isinstance(w, QuantizedTensorBase) else w for w in weight_tensors
+            w.dequantize() if isinstance(w, QuantizedTensorStorage) else w
+            for w in weight_tensors
         ]
return weight_tensors return weight_tensors
......
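For reference, the `fuse_wgrad_accumulation` plus `overwrite_main_grad` behavior documented above can be exercised roughly as follows. This is a minimal sketch assuming Megatron-style pre-allocated `main_grad` buffers; the layer sizes and the consumer of `main_grad` are illustrative and not part of this commit:

    import torch
    import transformer_engine.pytorch as te

    # With fuse_wgrad_accumulation=True the wgrad GEMM writes into
    # `param.main_grad` (a user-provided buffer) instead of `param.grad`.
    layer = te.Linear(1024, 1024, fuse_wgrad_accumulation=True, device="cuda")
    for param in layer.parameters():
        param.main_grad = torch.zeros_like(param, dtype=torch.float32)
        # New in this merge: with this attribute set, the wgrad GEMM runs with
        # accumulate=False, i.e. it overwrites main_grad instead of adding to it.
        param.overwrite_main_grad = True

    out = layer(torch.randn(32, 1024, device="cuda"))
    out.sum().backward()  # weight gradients land in param.main_grad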
...@@ -16,6 +16,7 @@ import transformer_engine_torch as tex ...@@ -16,6 +16,7 @@ import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe from transformer_engine.common.recipe import Recipe
from transformer_engine.pytorch import torch_version from transformer_engine.pytorch import torch_version
from transformer_engine.pytorch.tensor.utils import is_experimental
from .base import ( from .base import (
fill_userbuffers_buffer_for_all_gather, fill_userbuffers_buffer_for_all_gather,
get_workspace, get_workspace,
...@@ -26,9 +27,10 @@ from .base import ( ...@@ -26,9 +27,10 @@ from .base import (
_2X_ACC_DGRAD, _2X_ACC_DGRAD,
_2X_ACC_WGRAD, _2X_ACC_WGRAD,
) )
from ..fp8 import FP8GlobalStateManager from ..quantization import FP8GlobalStateManager
from ..utils import ( from ..utils import (
assert_dim_for_fp8_exec, assert_dim_for_fp8_exec,
assert_dim_for_all_gather,
cast_if_needed, cast_if_needed,
clear_tensor_data, clear_tensor_data,
divide, divide,
...@@ -57,7 +59,7 @@ from ..graph import is_graph_capturing ...@@ -57,7 +59,7 @@ from ..graph import is_graph_capturing
from ._common import apply_normalization, noop_cat, WeightGradStore from ._common import apply_normalization, noop_cat, WeightGradStore
from ..tensor.quantized_tensor import ( from ..tensor.quantized_tensor import (
QuantizedTensor, QuantizedTensor,
QuantizedTensorBase, QuantizedTensorStorage,
Quantizer, Quantizer,
prepare_for_saving, prepare_for_saving,
restore_from_saved, restore_from_saved,
...@@ -65,8 +67,8 @@ from ..tensor.quantized_tensor import ( ...@@ -65,8 +67,8 @@ from ..tensor.quantized_tensor import (
from ...debug.pytorch.debug_state import TEDebugState from ...debug.pytorch.debug_state import TEDebugState
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage
from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase from ..tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage
from ..export import is_in_onnx_export_mode, assert_warmed_up from ..export import is_in_onnx_export_mode, assert_warmed_up
from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
...@@ -144,6 +146,8 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -144,6 +146,8 @@ class _LayerNormLinear(torch.autograd.Function):
if ub_name is not None: if ub_name is not None:
nvtx_label = f"{nvtx_label}.{ub_name}" nvtx_label = f"{nvtx_label}.{ub_name}"
with_input_all_gather = parallel_mode == "column" and sequence_parallel
# Make sure input dimensions are compatible # Make sure input dimensions are compatible
out_features, in_features = weight.shape out_features, in_features = weight.shape
inp_shape = inp.shape inp_shape = inp.shape
...@@ -153,6 +157,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -153,6 +157,7 @@ class _LayerNormLinear(torch.autograd.Function):
inputmat = inp inputmat = inp
if fp8: if fp8:
assert_dim_for_fp8_exec(inputmat, weight) assert_dim_for_fp8_exec(inputmat, weight)
assert_dim_for_all_gather(inputmat, with_input_all_gather, input_quantizer)
# Cast for native AMP # Cast for native AMP
nvtx_range_push(f"{nvtx_label}.norm_input_cast") nvtx_range_push(f"{nvtx_label}.norm_input_cast")
...@@ -166,7 +171,6 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -166,7 +171,6 @@ class _LayerNormLinear(torch.autograd.Function):
weight_requires_grad = weight.requires_grad weight_requires_grad = weight.requires_grad
backward_needs_input = is_grad_enabled and weight_requires_grad backward_needs_input = is_grad_enabled and weight_requires_grad
with_input_all_gather = parallel_mode == "column" and sequence_parallel
# Configure Userbuffers communication (comm+GEMM overlap) # Configure Userbuffers communication (comm+GEMM overlap)
if debug: # turn off userbuffers in debug mode if debug: # turn off userbuffers in debug mode
...@@ -199,11 +203,13 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -199,11 +203,13 @@ class _LayerNormLinear(torch.autograd.Function):
# Avoid quantized norm kernel if norm output will be returned # Avoid quantized norm kernel if norm output will be returned
# or if a gather of ln_out must be in high precision. # or if a gather of ln_out must be in high precision.
experimental = is_experimental(input_quantizer)
with_quantized_norm = ( with_quantized_norm = (
fp8 fp8
and not debug and not debug
and not return_layernorm_output and not return_layernorm_output
and not return_layernorm_output_gathered and not return_layernorm_output_gathered
and not experimental # TODO(negvet): and not FP8GlobalStateManager.get_fp8_recipe().custom()
) )
# Apply normalization # Apply normalization
...@@ -249,7 +255,8 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -249,7 +255,8 @@ class _LayerNormLinear(torch.autograd.Function):
quantizer = None quantizer = None
if fp8 or debug: if fp8 or debug:
quantizer = input_quantizer quantizer = input_quantizer
if not with_quantized_norm: # experimental recipe doesn't need to support quantized AG
if not with_quantized_norm and not experimental:
ln_out = quantizer(ln_out) ln_out = quantizer(ln_out)
quantizer.set_usage(rowwise=True, columnwise=False) quantizer.set_usage(rowwise=True, columnwise=False)
if ub_overlap_ag_fprop: # Initialize Userbuffers all-gather if ub_overlap_ag_fprop: # Initialize Userbuffers all-gather
...@@ -280,7 +287,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -280,7 +287,7 @@ class _LayerNormLinear(torch.autograd.Function):
weightmat = weight weightmat = weight
quantized_weight = False quantized_weight = False
if fp8 or debug: if fp8 or debug:
quantized_weight = not isinstance(weight, QuantizedTensorBase) quantized_weight = not isinstance(weight, QuantizedTensorStorage)
# Configure quantizer # Configure quantizer
if weight_quantizer is not None: if weight_quantizer is not None:
...@@ -405,18 +412,18 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -405,18 +412,18 @@ class _LayerNormLinear(torch.autograd.Function):
# Input with column-wise usage is needed for wgrad GEMM. # Input with column-wise usage is needed for wgrad GEMM.
if backward_needs_input: if backward_needs_input:
if isinstance(ln_out, QuantizedTensorBase): if isinstance(ln_out, QuantizedTensorStorage):
# For sequence parallel in vanilla FP8, rowwise data is # For sequence parallel in vanilla FP8, rowwise data is
# to gather the input. For MXFP8, columnwise only data # to gather the input. For MXFP8, columnwise only data
# can be allgathered. # can be allgathered.
if ( if (
isinstance(ln_out, (MXFP8TensorBase, Float8BlockwiseQTensorBase)) isinstance(ln_out, (MXFP8TensorStorage, Float8BlockwiseQTensorStorage))
or not ctx.ln_out_needs_gather or not ctx.ln_out_needs_gather
): ):
ln_out.update_usage(rowwise_usage=False) ln_out.update_usage(rowwise_usage=False)
# Weight with column-wise usage is needed for dgrad GEMM. # Weight with column-wise usage is needed for dgrad GEMM.
if isinstance(weightmat, QuantizedTensorBase): if isinstance(weightmat, QuantizedTensorStorage):
weightmat.update_usage(columnwise_usage=True) weightmat.update_usage(columnwise_usage=True)
if cpu_offloading: if cpu_offloading:
...@@ -716,9 +723,9 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -716,9 +723,9 @@ class _LayerNormLinear(torch.autograd.Function):
# -------------------------------------------------- # --------------------------------------------------
# Make sure required data is available # Make sure required data is available
if isinstance(grad_output, QuantizedTensorBase): if isinstance(grad_output, QuantizedTensorStorage):
grad_output.update_usage(rowwise_usage=True) grad_output.update_usage(rowwise_usage=True)
if ctx.weight_quantizer is not None and isinstance(weight, QuantizedTensorBase): if ctx.weight_quantizer is not None and isinstance(weight, QuantizedTensorStorage):
weight.update_usage(columnwise_usage=True) weight.update_usage(columnwise_usage=True)
# Choose whether to use GEMM kernel with split accumulator # Choose whether to use GEMM kernel with split accumulator
...@@ -837,14 +844,14 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -837,14 +844,14 @@ class _LayerNormLinear(torch.autograd.Function):
ln_out_total_work.wait() ln_out_total_work.wait()
ln_out_total_work = None ln_out_total_work = None
if ctx.fp8 or ctx.debug: if ctx.fp8 or ctx.debug:
if isinstance(ln_out_total, QuantizedTensorBase): if isinstance(ln_out_total, QuantizedTensorStorage):
ln_out_total.update_usage(columnwise_usage=True) ln_out_total.update_usage(columnwise_usage=True)
else: else:
ctx.input_quantizer.set_usage(rowwise=False, columnwise=True) ctx.input_quantizer.set_usage(rowwise=False, columnwise=True)
ln_out_total = ctx.input_quantizer(ln_out_total) ln_out_total = ctx.input_quantizer(ln_out_total)
if ctx.fp8 or ctx.debug: if ctx.fp8 or ctx.debug:
if isinstance(grad_output, QuantizedTensorBase): if isinstance(grad_output, QuantizedTensorStorage):
grad_output.update_usage(columnwise_usage=True) grad_output.update_usage(columnwise_usage=True)
else: else:
ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True) ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
...@@ -880,7 +887,11 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -880,7 +887,11 @@ class _LayerNormLinear(torch.autograd.Function):
main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype
            ),
            "quantization_params": ctx.grad_weight_quantizer,
-           "accumulate": accumulate_wgrad_into_param_main_grad,
+           "accumulate": (
+               accumulate_wgrad_into_param_main_grad
+               if not getattr(weight, "overwrite_main_grad", False)
+               else False
+           ),
"layout": "NT", "layout": "NT",
"out": main_grad if ctx.fuse_wgrad_accumulation else None, "out": main_grad if ctx.fuse_wgrad_accumulation else None,
"bias": (bias if (grad_bias is None and not ctx.fp8) else None), "bias": (bias if (grad_bias is None and not ctx.fp8) else None),
...@@ -1037,7 +1048,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -1037,7 +1048,7 @@ class _LayerNormLinear(torch.autograd.Function):
nvtx_range_pop(f"{nvtx_label}.reduce_and_update_fp8_tensors") nvtx_range_pop(f"{nvtx_label}.reduce_and_update_fp8_tensors")
# Scatter fp8 weight buffers # Scatter fp8 weight buffers
# if ctx.fp8 and not isinstance(weight, QuantizedTensorBase): # if ctx.fp8 and not isinstance(weight, QuantizedTensorStorage):
# _fsdp_scatter_tensors(ctx.fsdp_group, weight_fp8) # _fsdp_scatter_tensors(ctx.fsdp_group, weight_fp8)
return ( return (
...@@ -1164,7 +1175,9 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1164,7 +1175,9 @@ class LayerNormLinear(TransformerEngineBaseModule):
the weight gradient. When enabled, it is assumed that the weights the weight gradient. When enabled, it is assumed that the weights
have an additional `main_grad` attribute (used instead of the have an additional `main_grad` attribute (used instead of the
regular `grad`) which is a pre-allocated buffer of the correct regular `grad`) which is a pre-allocated buffer of the correct
-    size to accumulate gradients in.
+    size to accumulate gradients in. This argument along with
+    weight tensor having attribute 'overwrite_main_grad' set to True
+    will overwrite `main_grad` instead of accumulating.
return_bias : bool, default = `False` return_bias : bool, default = `False`
when set to `True`, this module will not apply the additive bias itself, but when set to `True`, this module will not apply the additive bias itself, but
instead return the bias value during the forward pass together with the instead return the bias value during the forward pass together with the
...@@ -1470,6 +1483,8 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1470,6 +1483,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
self._customize_quantizers_float8_current_scaling(fwd, recipe) self._customize_quantizers_float8_current_scaling(fwd, recipe)
elif recipe.float8_block_scaling(): elif recipe.float8_block_scaling():
self._customize_quantizers_float8_blockwise_scaling(fwd, recipe) self._customize_quantizers_float8_blockwise_scaling(fwd, recipe)
elif recipe.nvfp4():
self._customize_quantizers_nvfp4(fwd, recipe)
# elif other recipes (mxfp8, etc) # elif other recipes (mxfp8, etc)
def reset_layer_norm_parameters(self) -> None: def reset_layer_norm_parameters(self) -> None:
...@@ -1812,7 +1827,29 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1812,7 +1827,29 @@ class LayerNormLinear(TransformerEngineBaseModule):
tex.FP8BwdTensors.GRAD_OUTPUT1 tex.FP8BwdTensors.GRAD_OUTPUT1
].amax_reduction_group = self.tp_group ].amax_reduction_group = self.tp_group
def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]: def _customize_quantizers_nvfp4(self, fwd: bool, recipe: Recipe) -> None:
"""Customize quantizers based on current scaling recipe + layernorm_linear."""
assert recipe.nvfp4(), "Incorrect recipe."
if fwd:
if self.sequence_parallel and self.parallel_mode == "column":
# set input_quantizer with amax reduction TP group
self.quantizers["scaling_fwd"][
tex.FP8FwdTensors.GEMM1_INPUT
].with_amax_reduction = True
self.quantizers["scaling_fwd"][
tex.FP8FwdTensors.GEMM1_INPUT
].amax_reduction_group = self.tp_group
else:
if self.sequence_parallel and self.parallel_mode == "row":
# customize grad_output_quantizer with amax reduction TP group
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT1
].with_amax_reduction = True
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT1
].amax_reduction_group = self.tp_group
def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage]]:
"""Get the weight tensors of the module.""" """Get the weight tensors of the module."""
unfused_weights = [getattr(self, name) for name in self.weight_names] unfused_weights = [getattr(self, name) for name in self.weight_names]
if any(isinstance(w, QuantizedTensor) for w in unfused_weights): if any(isinstance(w, QuantizedTensor) for w in unfused_weights):
......
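The `_customize_quantizers_nvfp4` method added above follows the same pattern as the current-scaling recipes: under sequence parallelism each rank quantizes only its local shard before the all-gather, so the amax that drives the quantization scale must be reduced over the tensor-parallel group, otherwise the gathered tensor would mix rank-dependent scales. Conceptually, a standalone sketch of the idea (not the TE quantizer internals; `fp_max` stands for the representable maximum of the target format):

    import torch
    import torch.distributed as dist

    def tp_synced_scale(local_shard: torch.Tensor, tp_group, fp_max: float) -> torch.Tensor:
        # Local absolute maximum of this rank's shard.
        amax = local_shard.abs().amax().float()
        # Reduce with MAX so every rank derives the identical scale.
        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=tp_group)
        return fp_max / torch.clamp(amax, min=torch.finfo(torch.float32).tiny)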
...@@ -18,6 +18,7 @@ import transformer_engine_torch as tex ...@@ -18,6 +18,7 @@ import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe from transformer_engine.common.recipe import Recipe
from transformer_engine.pytorch import torch_version from transformer_engine.pytorch import torch_version
from transformer_engine.pytorch.tensor.utils import is_experimental
from .base import ( from .base import (
fill_userbuffers_buffer_for_all_gather, fill_userbuffers_buffer_for_all_gather,
get_workspace, get_workspace,
...@@ -28,7 +29,7 @@ from .base import ( ...@@ -28,7 +29,7 @@ from .base import (
_2X_ACC_DGRAD, _2X_ACC_DGRAD,
_2X_ACC_WGRAD, _2X_ACC_WGRAD,
) )
from ..fp8 import FP8GlobalStateManager from ..quantization import FP8GlobalStateManager
from ..jit import ( from ..jit import (
bias_gelu_fused, bias_gelu_fused,
bgrad_dgelu_fused, bgrad_dgelu_fused,
...@@ -41,6 +42,7 @@ from ..utils import ( ...@@ -41,6 +42,7 @@ from ..utils import (
init_method_constant, init_method_constant,
cast_if_needed, cast_if_needed,
assert_dim_for_fp8_exec, assert_dim_for_fp8_exec,
assert_dim_for_all_gather,
clear_tensor_data, clear_tensor_data,
requires_grad, requires_grad,
needs_quantized_gemm, needs_quantized_gemm,
...@@ -65,11 +67,12 @@ from ..tensor.float8_tensor import ( ...@@ -65,11 +67,12 @@ from ..tensor.float8_tensor import (
Float8Tensor, Float8Tensor,
) )
from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor.nvfp4_tensor import NVFP4Quantizer
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ._common import apply_normalization, WeightGradStore from ._common import apply_normalization, WeightGradStore
from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ..tensor.quantized_tensor import ( from ..tensor.quantized_tensor import (
QuantizedTensorBase, QuantizedTensorStorage,
Quantizer, Quantizer,
prepare_for_saving, prepare_for_saving,
restore_from_saved, restore_from_saved,
...@@ -120,8 +123,14 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None): ...@@ -120,8 +123,14 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None):
"swiglu": (tex.swiglu, tex.dswiglu, None), "swiglu": (tex.swiglu, tex.dswiglu, None),
} }
# no activation fusion written yet # no activation fusion written yet
-    # Per-tensor current scaling or fp8 blockwise scaling: []
-    if recipe.float8_current_scaling() or recipe.float8_block_scaling():
+    # Per-tensor current scaling or fp8 blockwise scaling or custom quantization: []
+    # TODO(ksivaman): Fuse nvfp4 act once kernel is available.
+    if (
+        recipe.float8_current_scaling()
+        or recipe.float8_block_scaling()
+        or recipe.nvfp4()
+        or recipe.custom()
+    ):
return { return {
"gelu": (tex.gelu, tex.dgelu, None), "gelu": (tex.gelu, tex.dgelu, None),
"geglu": (tex.geglu, tex.dgeglu, None), "geglu": (tex.geglu, tex.dgeglu, None),
...@@ -218,6 +227,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -218,6 +227,7 @@ class _LayerNormMLP(torch.autograd.Function):
inputmat = inp.view((-1, in_features)) inputmat = inp.view((-1, in_features))
if fp8: if fp8:
assert_dim_for_fp8_exec(inputmat, fc1_weight, fc2_weight) assert_dim_for_fp8_exec(inputmat, fc1_weight, fc2_weight)
assert_dim_for_all_gather(inputmat, sequence_parallel, fc1_input_quantizer)
activation_func = _act_func( activation_func = _act_func(
activation, FP8GlobalStateManager.get_fp8_recipe() if fp8 else None activation, FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
...@@ -265,11 +275,13 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -265,11 +275,13 @@ class _LayerNormMLP(torch.autograd.Function):
# high precision layernorm output and output of the linear are returned # high precision layernorm output and output of the linear are returned
# for debug: : layernorm output = High precision to enable processing of this norm # for debug: : layernorm output = High precision to enable processing of this norm
experimental = is_experimental(fc1_input_quantizer)
with_quantized_norm = ( with_quantized_norm = (
fp8 fp8
and not debug and not debug
and not return_layernorm_output and not return_layernorm_output
and not return_layernorm_output_gathered and not return_layernorm_output_gathered
and not experimental
) )
# Apply normalization # Apply normalization
...@@ -309,7 +321,8 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -309,7 +321,8 @@ class _LayerNormMLP(torch.autograd.Function):
quantizer = None quantizer = None
if fp8 or debug: if fp8 or debug:
quantizer = fc1_input_quantizer quantizer = fc1_input_quantizer
if not with_quantized_norm: # experimental recipe doesn't need to support quantized AG
if not with_quantized_norm and not experimental:
ln_out = fc1_input_quantizer(ln_out) ln_out = fc1_input_quantizer(ln_out)
fc1_input_quantizer.set_usage(rowwise=True, columnwise=False) fc1_input_quantizer.set_usage(rowwise=True, columnwise=False)
if ub_overlap_ag: if ub_overlap_ag:
...@@ -447,10 +460,18 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -447,10 +460,18 @@ class _LayerNormMLP(torch.autograd.Function):
act_out = fc2_input_quantizer(act_out) act_out = fc2_input_quantizer(act_out)
else: else:
fc1_out, *_ = fc1_outputs fc1_out, *_ = fc1_outputs
-                if fp8 and FP8GlobalStateManager.get_fp8_recipe().float8_block_scaling():
-                    # tex.quantize does not support GELU fusion for blockwise.
-                    act_out = activation_func(fc1_out, None)
-                    act_out = tex.quantize(act_out, fc2_input_quantizer)
+                if fp8:
+                    recipe = FP8GlobalStateManager.get_fp8_recipe()
+                    if recipe.float8_block_scaling():
+                        # tex.quantize does not support GELU fusion for blockwise
+                        act_out = activation_func(fc1_out, None)
+                        act_out = tex.quantize(act_out, fc2_input_quantizer)
+                    elif recipe.custom():
+                        # tex.quantize does not support custom quantizers
+                        act_out = activation_func(fc1_out, None)
+                        act_out = fc2_input_quantizer(act_out)
+                    else:
+                        act_out = activation_func(fc1_out, fc2_input_quantizer)
else: else:
if fp8_calibration: if fp8_calibration:
act_out = activation_func(fc1_out, None) act_out = activation_func(fc1_out, None)
...@@ -521,9 +542,9 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -521,9 +542,9 @@ class _LayerNormMLP(torch.autograd.Function):
if is_grad_enabled: if is_grad_enabled:
# Weight with column-wise usage is needed for dgrad GEMM. # Weight with column-wise usage is needed for dgrad GEMM.
if isinstance(fc1_weight_final, QuantizedTensorBase): if isinstance(fc1_weight_final, QuantizedTensorStorage):
fc1_weight_final.update_usage(columnwise_usage=True) fc1_weight_final.update_usage(columnwise_usage=True)
if isinstance(fc2_weight_final, QuantizedTensorBase): if isinstance(fc2_weight_final, QuantizedTensorStorage):
fc2_weight_final.update_usage(columnwise_usage=True) fc2_weight_final.update_usage(columnwise_usage=True)
if cpu_offloading: if cpu_offloading:
...@@ -555,6 +576,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -555,6 +576,7 @@ class _LayerNormMLP(torch.autograd.Function):
if not fc2_weight.requires_grad: if not fc2_weight.requires_grad:
clear_tensor_data(act_out) clear_tensor_data(act_out)
act_out = None act_out = None
tensors_to_save, tensor_objects = prepare_for_saving( tensors_to_save, tensor_objects = prepare_for_saving(
inputmat, inputmat,
ln_weight, ln_weight,
...@@ -680,6 +702,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -680,6 +702,7 @@ class _LayerNormMLP(torch.autograd.Function):
mu, mu,
rsigma, rsigma,
) = restore_from_saved(ctx.tensor_objects, saved_tensors) ) = restore_from_saved(ctx.tensor_objects, saved_tensors)
# Delete the references to tensor objects once they've been consumed # Delete the references to tensor objects once they've been consumed
# by the `restore_from_saved` method to construct back the actual tensors. # by the `restore_from_saved` method to construct back the actual tensors.
ctx.tensor_objects = None ctx.tensor_objects = None
...@@ -820,10 +843,10 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -820,10 +843,10 @@ class _LayerNormMLP(torch.autograd.Function):
) )
# Make sure required data is available # Make sure required data is available
if isinstance(grad_output, QuantizedTensorBase): if isinstance(grad_output, QuantizedTensorStorage):
grad_output.update_usage(rowwise_usage=True) grad_output.update_usage(rowwise_usage=True)
if ctx.fc2_weight_quantizer is not None and isinstance( if ctx.fc2_weight_quantizer is not None and isinstance(
ctx.fc2_weight, QuantizedTensorBase ctx.fc2_weight, QuantizedTensorStorage
): ):
ctx.fc2_weight.update_usage(columnwise_usage=True) ctx.fc2_weight.update_usage(columnwise_usage=True)
...@@ -905,14 +928,14 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -905,14 +928,14 @@ class _LayerNormMLP(torch.autograd.Function):
# Note: Synchronize tensor-parallel communication and # Note: Synchronize tensor-parallel communication and
# make sure required data is available # make sure required data is available
if ctx.fp8 or ctx.debug: if ctx.fp8 or ctx.debug:
if isinstance(act_out, QuantizedTensorBase): if isinstance(act_out, QuantizedTensorStorage):
act_out.update_usage(columnwise_usage=True) act_out.update_usage(columnwise_usage=True)
else: else:
ctx.fc2_input_quantizer.set_usage(rowwise=False, columnwise=True) ctx.fc2_input_quantizer.set_usage(rowwise=False, columnwise=True)
act_out = ctx.fc2_input_quantizer(act_out) act_out = ctx.fc2_input_quantizer(act_out)
if ctx.fp8 or ctx.debug: if ctx.fp8 or ctx.debug:
if isinstance(grad_output, QuantizedTensorBase): if isinstance(grad_output, QuantizedTensorStorage):
grad_output.update_usage(columnwise_usage=True) grad_output.update_usage(columnwise_usage=True)
else: else:
ctx.fc2_grad_output_quantizer.set_usage(rowwise=False, columnwise=True) ctx.fc2_grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
...@@ -932,7 +955,11 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -932,7 +955,11 @@ class _LayerNormMLP(torch.autograd.Function):
else ctx.activation_dtype else ctx.activation_dtype
            ),
            "quantization_params": ctx.fc2_grad_weight_quantizer,  # wgrad in high precision
-           "accumulate": accumulate_wgrad_into_param_main_grad,
+           "accumulate": (
+               accumulate_wgrad_into_param_main_grad
+               if not getattr(fc1_weight, "overwrite_main_grad", False)
+               else False
+           ),
"layout": "NT", "layout": "NT",
"out": origin_fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None, "out": origin_fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
"bias": fc2_bias if fc2_bias is not None and fc2_bias_grad is None else None, "bias": fc2_bias if fc2_bias is not None and fc2_bias_grad is None else None,
...@@ -1028,8 +1055,11 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -1028,8 +1055,11 @@ class _LayerNormMLP(torch.autograd.Function):
) # activation in high precision ) # activation in high precision
if ctx.fp8: if ctx.fp8:
-                # TODO float8 blockwise current scaling has no bgrad fusion for now
-                if isinstance(ctx.fc1_grad_output_quantizer, Float8BlockQuantizer):
+                # TODO float8 blockwise current scaling (as well as custom quantizers) has no bgrad fusion for now
+                if (
+                    isinstance(ctx.fc1_grad_output_quantizer, Float8BlockQuantizer)
+                    or ctx.fp8_recipe.custom()
+                ):
fc1_bias_grad = dact.view(-1, dact.shape[-1]).sum(dim=0) fc1_bias_grad = dact.view(-1, dact.shape[-1]).sum(dim=0)
dact = ctx.fc1_grad_output_quantizer(dact) dact = ctx.fc1_grad_output_quantizer(dact)
else: else:
...@@ -1074,7 +1104,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -1074,7 +1104,7 @@ class _LayerNormMLP(torch.autograd.Function):
# Make sure required data is available # Make sure required data is available
if ctx.fc1_weight_quantizer is not None and isinstance( if ctx.fc1_weight_quantizer is not None and isinstance(
ctx.fc1_weight_quantizer, QuantizedTensorBase ctx.fc1_weight_quantizer, QuantizedTensorStorage
): ):
ctx.fc1_weight.update_usage(columnwise_usage=True) ctx.fc1_weight.update_usage(columnwise_usage=True)
...@@ -1145,7 +1175,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -1145,7 +1175,7 @@ class _LayerNormMLP(torch.autograd.Function):
ln_out_total_work.wait() ln_out_total_work.wait()
ln_out_total_work = None ln_out_total_work = None
if ctx.fp8 or ctx.debug: if ctx.fp8 or ctx.debug:
if isinstance(ln_out_total, QuantizedTensorBase): if isinstance(ln_out_total, QuantizedTensorStorage):
ln_out_total.update_usage(columnwise_usage=True) ln_out_total.update_usage(columnwise_usage=True)
else: else:
ctx.fc1_input_quantizer.set_usage(rowwise=False, columnwise=True) ctx.fc1_input_quantizer.set_usage(rowwise=False, columnwise=True)
...@@ -1155,7 +1185,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -1155,7 +1185,7 @@ class _LayerNormMLP(torch.autograd.Function):
# Note: Synchronize tensor-parallel communication and # Note: Synchronize tensor-parallel communication and
# make sure required data is available # make sure required data is available
if ctx.fp8 or ctx.debug: if ctx.fp8 or ctx.debug:
if isinstance(dact, QuantizedTensorBase): if isinstance(dact, QuantizedTensorStorage):
dact.update_usage(columnwise_usage=True) dact.update_usage(columnwise_usage=True)
else: else:
ctx.fc1_grad_output_quantizer.set_usage(rowwise=False, columnwise=True) ctx.fc1_grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
...@@ -1178,7 +1208,11 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -1178,7 +1208,11 @@ class _LayerNormMLP(torch.autograd.Function):
else ctx.activation_dtype else ctx.activation_dtype
            ),
            "quantization_params": ctx.fc1_grad_weight_quantizer,
-           "accumulate": accumulate_wgrad_into_param_main_grad,
+           "accumulate": (
+               accumulate_wgrad_into_param_main_grad
+               if not getattr(fc2_weight, "overwrite_main_grad", False)
+               else False
+           ),
"layout": "NT", "layout": "NT",
"out": origin_fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None, "out": origin_fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
"bias": fc1_bias if fuse_gemm_and_bias_fc1_wgrad else None, "bias": fc1_bias if fuse_gemm_and_bias_fc1_wgrad else None,
...@@ -1486,7 +1520,9 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1486,7 +1520,9 @@ class LayerNormMLP(TransformerEngineBaseModule):
the weight gradient. When enabled, it is assumed that the weights the weight gradient. When enabled, it is assumed that the weights
have an additional `main_grad` attribute (used instead of the have an additional `main_grad` attribute (used instead of the
regular `grad`) which is a pre-allocated buffer of the correct regular `grad`) which is a pre-allocated buffer of the correct
-    size to accumulate gradients in.
+    size to accumulate gradients in. This argument along with
+    weight tensor having attribute 'overwrite_main_grad' set to True
+    will overwrite `main_grad` instead of accumulating.
return_bias : bool, default = `False` return_bias : bool, default = `False`
when set to `True`, this module will not apply the additive bias for FC2, but when set to `True`, this module will not apply the additive bias for FC2, but
instead return the bias value during the forward pass together with the instead return the bias value during the forward pass together with the
...@@ -1718,6 +1754,8 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1718,6 +1754,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
self._customize_quantizers_float8_current_scaling(fwd, recipe) self._customize_quantizers_float8_current_scaling(fwd, recipe)
elif recipe.float8_block_scaling(): elif recipe.float8_block_scaling():
self._customize_quantizers_float8_blockwise_scaling(fwd, recipe) self._customize_quantizers_float8_blockwise_scaling(fwd, recipe)
elif recipe.nvfp4():
self._customize_quantizers_nvfp4(fwd, recipe)
# elif for other recipes (mxfp8, etc.) # elif for other recipes (mxfp8, etc.)
def reset_layer_norm_parameters(self) -> None: def reset_layer_norm_parameters(self) -> None:
...@@ -1937,7 +1975,10 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1937,7 +1975,10 @@ class LayerNormMLP(TransformerEngineBaseModule):
fc2_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT] fc2_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT]
fc2_input_quantizer.set_usage( fc2_input_quantizer.set_usage(
rowwise=True, rowwise=True,
-            columnwise=isinstance(fc2_input_quantizer, (MXFP8Quantizer, Float8BlockQuantizer)),
+            columnwise=isinstance(
+                fc2_input_quantizer,
+                (MXFP8Quantizer, Float8BlockQuantizer, NVFP4Quantizer),
+            ),
) )
fc1_input_quantizer.internal = True fc1_input_quantizer.internal = True
if fp8_output: if fp8_output:
...@@ -2142,7 +2183,29 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -2142,7 +2183,29 @@ class LayerNormMLP(TransformerEngineBaseModule):
tex.FP8BwdTensors.GRAD_OUTPUT2 tex.FP8BwdTensors.GRAD_OUTPUT2
].amax_reduction_group = self.tp_group ].amax_reduction_group = self.tp_group
def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]: def _customize_quantizers_nvfp4(self, fwd: bool, recipe: Recipe) -> None:
"""Customize quantizers based on current scaling recipe + layernorm_mlp."""
assert recipe.nvfp4(), "Incorrect recipe."
if fwd:
if self.sequence_parallel and self.set_parallel_mode:
# fc1_input_quantizer: customize input_quantizer with amax reduction TP group, column parallel + sequence parallel here
self.quantizers["scaling_fwd"][
tex.FP8FwdTensors.GEMM1_INPUT
].with_amax_reduction = True
self.quantizers["scaling_fwd"][
tex.FP8FwdTensors.GEMM1_INPUT
].amax_reduction_group = self.tp_group
else:
if self.sequence_parallel and self.set_parallel_mode:
# fc2_grad_output_quantizer: customize grad_output_quantizer with amax reduction TP group, row parallel + sequence parallel here
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT2
].with_amax_reduction = True
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT2
].amax_reduction_group = self.tp_group
def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage]]:
"""Get the weight tensors of the module.""" """Get the weight tensors of the module."""
return [self.fc1_weight, self.fc2_weight] return [self.fc1_weight, self.fc2_weight]
......
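As the backward-path change above shows, recipes without a fused bias-gradient kernel (blockwise scaling and, now, custom quantizers) fall back to computing the bias gradient as a plain sum over the flattened token dimension and quantizing the activation gradient separately. In isolation, the fallback amounts to the following sketch; `quantizer` stands in for `fc1_grad_output_quantizer` and may be any callable that quantizes a high-precision tensor:

    import torch

    def unfused_bgrad_and_quantize(dact: torch.Tensor, quantizer):
        # Bias gradient: reduce over all tokens (rows of the flattened gradient).
        bgrad = dact.view(-1, dact.shape[-1]).sum(dim=0)
        # Quantize the activation gradient afterwards, since no fused kernel exists.
        dact_q = quantizer(dact)
        return bgrad, dact_q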
...@@ -27,7 +27,7 @@ from .base import ( ...@@ -27,7 +27,7 @@ from .base import (
_2X_ACC_WGRAD, _2X_ACC_WGRAD,
) )
from ._common import noop_cat, WeightGradStore from ._common import noop_cat, WeightGradStore
from ..fp8 import FP8GlobalStateManager from ..quantization import FP8GlobalStateManager
from ..utils import ( from ..utils import (
cast_if_needed, cast_if_needed,
clear_tensor_data, clear_tensor_data,
...@@ -36,6 +36,7 @@ from ..utils import ( ...@@ -36,6 +36,7 @@ from ..utils import (
requires_grad, requires_grad,
needs_quantized_gemm, needs_quantized_gemm,
assert_dim_for_fp8_exec, assert_dim_for_fp8_exec,
assert_dim_for_all_gather,
nvtx_range_pop, nvtx_range_pop,
nvtx_range_push, nvtx_range_push,
get_activation_offloading, get_activation_offloading,
...@@ -60,13 +61,14 @@ from ..jit import no_torch_dynamo ...@@ -60,13 +61,14 @@ from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing from ..graph import is_graph_capturing
from ..tensor.quantized_tensor import ( from ..tensor.quantized_tensor import (
QuantizedTensor, QuantizedTensor,
QuantizedTensorBase, QuantizedTensorStorage,
Quantizer, Quantizer,
prepare_for_saving, prepare_for_saving,
restore_from_saved, restore_from_saved,
) )
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor.utils import is_experimental
from ..export import is_in_onnx_export_mode, assert_warmed_up from ..export import is_in_onnx_export_mode, assert_warmed_up
from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...debug.pytorch.debug_state import TEDebugState from ...debug.pytorch.debug_state import TEDebugState
...@@ -154,6 +156,9 @@ class _Linear(torch.autograd.Function): ...@@ -154,6 +156,9 @@ class _Linear(torch.autograd.Function):
ub_obj = get_ub(ub_name + "_fprop", fp8) ub_obj = get_ub(ub_name + "_fprop", fp8)
ub_type = tex.CommOverlapType.AG ub_type = tex.CommOverlapType.AG
# experimental recipe check
experimental = is_experimental(input_quantizer) or is_experimental(weight_quantizer)
# ------------------------------------------------------ # ------------------------------------------------------
# Prepare input tensor # Prepare input tensor
# Note: Cast to expected dtype and perform tensor-parallel communication # Note: Cast to expected dtype and perform tensor-parallel communication
...@@ -164,6 +169,7 @@ class _Linear(torch.autograd.Function): ...@@ -164,6 +169,7 @@ class _Linear(torch.autograd.Function):
own_quantized_input = False own_quantized_input = False
if fp8: if fp8:
assert_dim_for_fp8_exec(inputmat, weight) assert_dim_for_fp8_exec(inputmat, weight)
assert_dim_for_all_gather(inputmat, with_input_all_gather_nccl, input_quantizer)
if save_original_input: if save_original_input:
assert not isinstance( assert not isinstance(
input_quantizer, Float8Quantizer input_quantizer, Float8Quantizer
...@@ -175,7 +181,7 @@ class _Linear(torch.autograd.Function): ...@@ -175,7 +181,7 @@ class _Linear(torch.autograd.Function):
if fp8 or debug: if fp8 or debug:
if input_quantizer is None: if input_quantizer is None:
raise ValueError("Missing quantizer for input tensor") raise ValueError("Missing quantizer for input tensor")
if not isinstance(inputmat, QuantizedTensorBase): if not isinstance(inputmat, QuantizedTensorStorage) and not experimental:
own_quantized_input = True own_quantized_input = True
input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input) input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input)
if isinstance( if isinstance(
...@@ -213,7 +219,7 @@ class _Linear(torch.autograd.Function): ...@@ -213,7 +219,7 @@ class _Linear(torch.autograd.Function):
else: # Do not all-gather input tensor else: # Do not all-gather input tensor
if fp8 or debug: if fp8 or debug:
if isinstance(inputmat, QuantizedTensorBase): if isinstance(inputmat, QuantizedTensorStorage):
inputmat.update_usage(rowwise_usage=True) inputmat.update_usage(rowwise_usage=True)
else: else:
if input_quantizer is None: if input_quantizer is None:
...@@ -369,7 +375,7 @@ class _Linear(torch.autograd.Function): ...@@ -369,7 +375,7 @@ class _Linear(torch.autograd.Function):
if ( if (
backward_needs_input backward_needs_input
and own_quantized_input and own_quantized_input
and isinstance(inputmat, QuantizedTensorBase) and isinstance(inputmat, QuantizedTensorStorage)
): ):
if ( if (
ctx.backward_input_needs_gather ctx.backward_input_needs_gather
...@@ -388,7 +394,7 @@ class _Linear(torch.autograd.Function): ...@@ -388,7 +394,7 @@ class _Linear(torch.autograd.Function):
# Weight with column-wise usage is needed for dgrad GEMM. # Weight with column-wise usage is needed for dgrad GEMM.
if inp.requires_grad: if inp.requires_grad:
if isinstance(weightmat, QuantizedTensorBase): if isinstance(weightmat, QuantizedTensorStorage):
weightmat.update_usage(columnwise_usage=True) weightmat.update_usage(columnwise_usage=True)
if cpu_offloading and saved_inputmat is not None: if cpu_offloading and saved_inputmat is not None:
...@@ -401,7 +407,7 @@ class _Linear(torch.autograd.Function): ...@@ -401,7 +407,7 @@ class _Linear(torch.autograd.Function):
ctx.fsdp_shapes = _fsdp_scatter_tensors( ctx.fsdp_shapes = _fsdp_scatter_tensors(
fsdp_group, fsdp_group,
saved_inputmat, saved_inputmat,
weightmat if fp8 and not isinstance(weight, QuantizedTensorBase) else None, weightmat if fp8 and not isinstance(weight, QuantizedTensorStorage) else None,
) )
nvtx_range_pop(f"{nvtx_label}.fsdp_scatter") nvtx_range_pop(f"{nvtx_label}.fsdp_scatter")
...@@ -471,6 +477,7 @@ class _Linear(torch.autograd.Function): ...@@ -471,6 +477,7 @@ class _Linear(torch.autograd.Function):
ctx.main_grad_func = lambda: weight.main_grad ctx.main_grad_func = lambda: weight.main_grad
ctx.debug = debug ctx.debug = debug
ctx.experimental = experimental
ctx.cpu_offloading = cpu_offloading ctx.cpu_offloading = cpu_offloading
ctx.is_first_microbatch = is_first_microbatch ctx.is_first_microbatch = is_first_microbatch
ctx.use_bias = bias is not None ctx.use_bias = bias is not None
...@@ -637,10 +644,10 @@ class _Linear(torch.autograd.Function): ...@@ -637,10 +644,10 @@ class _Linear(torch.autograd.Function):
inputmat_total_work = None inputmat_total_work = None
if ctx.requires_wgrad: if ctx.requires_wgrad:
if ctx.fp8 or ctx.debug: if ctx.fp8 or ctx.debug:
if isinstance(inputmat, QuantizedTensorBase): if isinstance(inputmat, QuantizedTensorStorage):
# Input tensor is already quantized # Input tensor is already quantized
pass pass
elif ctx.debug: elif ctx.debug or ctx.experimental:
# Debug quantizer will be applied immediately before wgrad GEMM # Debug quantizer will be applied immediately before wgrad GEMM
pass pass
else: else:
...@@ -656,7 +663,7 @@ class _Linear(torch.autograd.Function): ...@@ -656,7 +663,7 @@ class _Linear(torch.autograd.Function):
quantizer.set_usage(rowwise=False, columnwise=True) quantizer.set_usage(rowwise=False, columnwise=True)
inputmat = quantizer(inputmat) inputmat = quantizer(inputmat)
else: else:
if isinstance(inputmat, QuantizedTensorBase): if isinstance(inputmat, QuantizedTensorStorage):
inputmat = inputmat.dequantize(dtype=ctx.activation_dtype) inputmat = inputmat.dequantize(dtype=ctx.activation_dtype)
else: else:
inputmat = cast_if_needed(inputmat, ctx.activation_dtype) inputmat = cast_if_needed(inputmat, ctx.activation_dtype)
...@@ -701,9 +708,11 @@ class _Linear(torch.autograd.Function): ...@@ -701,9 +708,11 @@ class _Linear(torch.autograd.Function):
if ctx.requires_dgrad: if ctx.requires_dgrad:
# Make sure required data is available # Make sure required data is available
if isinstance(grad_output, QuantizedTensorBase): if isinstance(grad_output, QuantizedTensorStorage):
grad_output.update_usage(rowwise_usage=True) grad_output.update_usage(rowwise_usage=True)
-            if ctx.weight_quantizer is not None and isinstance(weight_fp8, QuantizedTensorBase):
+            if ctx.weight_quantizer is not None and isinstance(
+                weight_fp8, QuantizedTensorStorage
+            ):
weight_fp8.update_usage(columnwise_usage=True) weight_fp8.update_usage(columnwise_usage=True)
# Choose whether to use GEMM kernel with split accumulator # Choose whether to use GEMM kernel with split accumulator
...@@ -729,6 +738,7 @@ class _Linear(torch.autograd.Function): ...@@ -729,6 +738,7 @@ class _Linear(torch.autograd.Function):
# dgrad GEMM # dgrad GEMM
# Note: dx = dy * w # Note: dx = dy * w
nvtx_range_push(f"{nvtx_label}.dgrad_gemm") nvtx_range_push(f"{nvtx_label}.dgrad_gemm")
gemm_out, *_, reduce_scatter_out = general_gemm( gemm_out, *_, reduce_scatter_out = general_gemm(
weight_fp8, weight_fp8,
...@@ -786,7 +796,7 @@ class _Linear(torch.autograd.Function): ...@@ -786,7 +796,7 @@ class _Linear(torch.autograd.Function):
inputmat_total_work.wait() inputmat_total_work.wait()
inputmat_total_work = None inputmat_total_work = None
if ctx.fp8 or ctx.debug: if ctx.fp8 or ctx.debug:
if isinstance(inputmat_total, QuantizedTensorBase): if isinstance(inputmat_total, QuantizedTensorStorage):
inputmat_total.update_usage(columnwise_usage=True) inputmat_total.update_usage(columnwise_usage=True)
else: else:
ctx.input_quantizer.set_usage(rowwise=False, columnwise=True) ctx.input_quantizer.set_usage(rowwise=False, columnwise=True)
...@@ -828,7 +838,7 @@ class _Linear(torch.autograd.Function): ...@@ -828,7 +838,7 @@ class _Linear(torch.autograd.Function):
) )
if ctx.fp8 or ctx.debug: if ctx.fp8 or ctx.debug:
if isinstance(grad_output, QuantizedTensorBase): if isinstance(grad_output, QuantizedTensorStorage):
grad_output.update_usage(columnwise_usage=True) grad_output.update_usage(columnwise_usage=True)
else: else:
ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True) ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
...@@ -864,7 +874,11 @@ class _Linear(torch.autograd.Function): ...@@ -864,7 +874,11 @@ class _Linear(torch.autograd.Function):
main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype
            ),
            "quantization_params": ctx.grad_weight_quantizer,
-           "accumulate": accumulate_wgrad_into_param_main_grad,
+           "accumulate": (
+               accumulate_wgrad_into_param_main_grad
+               if not getattr(weight, "overwrite_main_grad", False)
+               else False
+           ),
"layout": "NT", "layout": "NT",
"out": main_grad if ctx.fuse_wgrad_accumulation else None, "out": main_grad if ctx.fuse_wgrad_accumulation else None,
"bias": (bias if (grad_bias is None and not ctx.fp8) else None), "bias": (bias if (grad_bias is None and not ctx.fp8) else None),
...@@ -984,7 +998,7 @@ class _Linear(torch.autograd.Function): ...@@ -984,7 +998,7 @@ class _Linear(torch.autograd.Function):
nvtx_range_pop(f"{nvtx_label}.reduce_and_update_fp8_tensors") nvtx_range_pop(f"{nvtx_label}.reduce_and_update_fp8_tensors")
# Scatter fp8 weight buffers # Scatter fp8 weight buffers
if ctx.fp8 and not isinstance(weight, QuantizedTensorBase): if ctx.fp8 and not isinstance(weight, QuantizedTensorStorage):
_fsdp_scatter_tensors(ctx.fsdp_group, weight_fp8) _fsdp_scatter_tensors(ctx.fsdp_group, weight_fp8)
return ( return (
wgrad, wgrad,
...@@ -1086,7 +1100,9 @@ class Linear(TransformerEngineBaseModule): ...@@ -1086,7 +1100,9 @@ class Linear(TransformerEngineBaseModule):
the weight gradient. When enabled, it is assumed that the weights the weight gradient. When enabled, it is assumed that the weights
have an additional `main_grad` attribute (used instead of the have an additional `main_grad` attribute (used instead of the
regular `grad`) which is a pre-allocated buffer of the correct regular `grad`) which is a pre-allocated buffer of the correct
-    size to accumulate gradients in.
+    size to accumulate gradients in. This argument along with
+    weight tensor having attribute 'overwrite_main_grad' set to True
+    will overwrite `main_grad` instead of accumulating.
return_bias : bool, default = `False` return_bias : bool, default = `False`
when set to `True`, this module will not apply the additive bias itself, but when set to `True`, this module will not apply the additive bias itself, but
instead return the bias value during the forward pass together with the instead return the bias value during the forward pass together with the
...@@ -1363,6 +1379,8 @@ class Linear(TransformerEngineBaseModule): ...@@ -1363,6 +1379,8 @@ class Linear(TransformerEngineBaseModule):
self._customize_quantizers_float8_current_scaling(fwd, recipe) self._customize_quantizers_float8_current_scaling(fwd, recipe)
elif recipe.float8_block_scaling(): elif recipe.float8_block_scaling():
self._customize_quantizers_float8_blockwise_scaling(fwd, recipe) self._customize_quantizers_float8_blockwise_scaling(fwd, recipe)
elif recipe.nvfp4():
self._customize_quantizers_nvfp4(fwd, recipe)
# elif for other recipes (mxfp8, etc.) # elif for other recipes (mxfp8, etc.)
def reset_parameters(self, defer_init=False): def reset_parameters(self, defer_init=False):
...@@ -1452,7 +1470,6 @@ class Linear(TransformerEngineBaseModule): ...@@ -1452,7 +1470,6 @@ class Linear(TransformerEngineBaseModule):
if not debug if not debug
else self._get_debug_quantizers(fp8_output, fp8_grad) else self._get_debug_quantizers(fp8_output, fp8_grad)
) )
if debug: if debug:
if self.no_debug_features_active(quantizers): if self.no_debug_features_active(quantizers):
debug = False debug = False
...@@ -1557,7 +1574,7 @@ class Linear(TransformerEngineBaseModule): ...@@ -1557,7 +1574,7 @@ class Linear(TransformerEngineBaseModule):
for name, q in zip(names, original_quantizers) for name, q in zip(names, original_quantizers)
) )
def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]: def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage]]:
"""Get the weight tensors of the module.""" """Get the weight tensors of the module."""
unfused_weights = [getattr(self, name) for name in self.weight_names] unfused_weights = [getattr(self, name) for name in self.weight_names]
if any(isinstance(w, QuantizedTensor) for w in unfused_weights): if any(isinstance(w, QuantizedTensor) for w in unfused_weights):
...@@ -1693,6 +1710,28 @@ class Linear(TransformerEngineBaseModule): ...@@ -1693,6 +1710,28 @@ class Linear(TransformerEngineBaseModule):
tex.FP8BwdTensors.GRAD_OUTPUT1 tex.FP8BwdTensors.GRAD_OUTPUT1
].amax_reduction_group = self.tp_group ].amax_reduction_group = self.tp_group
def _customize_quantizers_nvfp4(self, fwd: bool, recipe: Recipe) -> None:
"""Customize quantizers based on current scaling recipe + linear."""
assert recipe.nvfp4(), "Incorrect recipe."
if fwd:
if self.sequence_parallel and self.parallel_mode == "column":
# customize input_quantizer with amax reduction TP group
self.quantizers["scaling_fwd"][
tex.FP8FwdTensors.GEMM1_INPUT
].with_amax_reduction = True
self.quantizers["scaling_fwd"][
tex.FP8FwdTensors.GEMM1_INPUT
].amax_reduction_group = self.tp_group
else:
if self.sequence_parallel and self.parallel_mode == "row":
# customize grad_output_quantizer with amax reduction TP group
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT1
].with_amax_reduction = True
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT1
].amax_reduction_group = self.tp_group
def _get_weight_quantizers(self) -> List[Quantizer]: def _get_weight_quantizers(self) -> List[Quantizer]:
"""Get the weight quantizers of the module.""" """Get the weight quantizers of the module."""
if not self.fp8 and not self.fp8_calibration: if not self.fp8 and not self.fp8_calibration:
......
...@@ -11,19 +11,19 @@ import torch ...@@ -11,19 +11,19 @@ import torch
from transformer_engine_torch import FP8TensorMeta from transformer_engine_torch import FP8TensorMeta
from .. import torch_version from .. import torch_version
from ..fp8 import FP8GlobalStateManager from ..quantization import FP8GlobalStateManager
from ..tensor.float8_tensor import Float8Tensor from ..tensor.float8_tensor import Float8Tensor
from ..tensor.quantized_tensor import QuantizedTensorBase from ..tensor.quantized_tensor import QuantizedTensorStorage
from ..utils import canonicalize_dtype from ..utils import canonicalize_dtype
def is_quantized_tensor(tensor: torch.Tensor | QuantizedTensorBase) -> bool: def is_quantized_tensor(tensor: torch.Tensor | QuantizedTensorStorage) -> bool:
"""Check if tensor is a quantized tensor""" """Check if tensor is a quantized tensor"""
return isinstance(tensor, QuantizedTensorBase) return isinstance(tensor, QuantizedTensorStorage)
def maybe_dequantize( def maybe_dequantize(
tensor: torch.Tensor | QuantizedTensorBase, dtype: torch.dtype | None = None tensor: torch.Tensor | QuantizedTensorStorage, dtype: torch.dtype | None = None
) -> torch.Tensor: ) -> torch.Tensor:
"""Dequantize tensor to given dtype or just convert if not a quantized tensor""" """Dequantize tensor to given dtype or just convert if not a quantized tensor"""
if is_quantized_tensor(tensor): if is_quantized_tensor(tensor):
......
...@@ -4,7 +4,19 @@ ...@@ -4,7 +4,19 @@
"""Single tensor operations supported by the operation fuser.""" """Single tensor operations supported by the operation fuser."""
from .activation import GELU, GEGLU, QGELU, QGEGLU, ReLU, ReGLU, SReLU, SReGLU, SiLU, SwiGLU from .activation import (
GELU,
GEGLU,
QGELU,
QGEGLU,
ReLU,
ReGLU,
SReLU,
SReGLU,
SiLU,
SwiGLU,
ClampedSwiGLU,
)
from .add_extra_input import AddExtraInput from .add_extra_input import AddExtraInput
from .all_gather import AllGather from .all_gather import AllGather
from .all_reduce import AllReduce from .all_reduce import AllReduce
......
...@@ -28,6 +28,7 @@ __all__ = [ ...@@ -28,6 +28,7 @@ __all__ = [
"SReGLU", "SReGLU",
"SiLU", "SiLU",
"SwiGLU", "SwiGLU",
"ClampedSwiGLU",
] ]
...@@ -392,3 +393,38 @@ class SwiGLU(_ActivationOperation): ...@@ -392,3 +393,38 @@ class SwiGLU(_ActivationOperation):
def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
return tex.dswiglu(*args, **kwargs) return tex.dswiglu(*args, **kwargs)
class ClampedSwiGLU(_ActivationOperation):
r"""GPT-OSS
Implementation based on `GPT-OSS<https://github.com/openai/gpt-oss/blob/a0a84273e9e0c14a233cb9befdfd159c2bcfa6cd/gpt_oss/torch/model.py#L250>`__.
This activation has two differences compared to the original SwiGLU
1. Both gate and pre-activations are clipped based on parameter limit.
2. Activation uses sigmoid(alpha * x) instead of sigmoid(x) used in Swish activation.
.. warning:: The input tensor is chunked along the last dimension to get gates/pre-activations which is differnt
from GPT OSS implementation where the gates/pre-activations are assumed to be interleaved in the input tensor.
Parameters
----------
limit: float
The clamp limit.
alpha: float
The scaling factor for the sigmoid function used in the activation.
cache_quantized_input: bool, default = False
Quantize input tensor when caching for use in the backward pass.
"""
def __init__(
self, *, limit: float = 7.0, alpha: float = 1.702, cache_quantized_input: bool = False
):
super().__init__(cache_quantized_input=cache_quantized_input)
self.limit = limit
self.alpha = alpha
def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor:
return tex.clamped_swiglu(*args, limit=self.limit, alpha=self.alpha, **kwargs)
def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
return tex.clamped_dswiglu(*args, limit=self.limit, alpha=self.alpha, **kwargs)
...@@ -19,7 +19,7 @@ from ...distributed import ( ...@@ -19,7 +19,7 @@ from ...distributed import (
gather_along_first_dim, gather_along_first_dim,
reduce_scatter_along_first_dim, reduce_scatter_along_first_dim,
) )
from ...fp8 import FP8GlobalStateManager, Recipe from ...quantization import FP8GlobalStateManager, Recipe
from ...module.base import ( from ...module.base import (
_2X_ACC_FPROP, _2X_ACC_FPROP,
_2X_ACC_DGRAD, _2X_ACC_DGRAD,
...@@ -29,7 +29,7 @@ from ...module.base import ( ...@@ -29,7 +29,7 @@ from ...module.base import (
) )
from ...tensor import Quantizer from ...tensor import Quantizer
from ...tensor.float8_tensor import Float8Quantizer from ...tensor.float8_tensor import Float8Quantizer
from ...tensor._internal.float8_tensor_base import Float8TensorBase from ...tensor.storage.float8_tensor_storage import Float8TensorStorage
from ...utils import ( from ...utils import (
canonicalize_device, canonicalize_device,
canonicalize_dtype, canonicalize_dtype,
...@@ -80,7 +80,9 @@ class BasicLinear(BasicOperation): ...@@ -80,7 +80,9 @@ class BasicLinear(BasicOperation):
autograd. The weight's `main_grad` must be set externally and autograd. The weight's `main_grad` must be set externally and
there is no guarantee that `grad` will be set or be there is no guarantee that `grad` will be set or be
meaningful. This is primarily intended to integrate with meaningful. This is primarily intended to integrate with
Megatron-LM. Megatron-LM. If the weight tensor also has the
attribute `overwrite_main_grad` set to `True`, `main_grad` is
overwritten instead of accumulated into (see the sketch below).
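A hedged sketch of the external contract this describes, at the parameter level only (no constructor arguments are assumed): `main_grad` is set by the caller, and after backward the gradient should be read from it rather than from `grad`:
import torch

weight = torch.nn.Parameter(torch.empty(1024, 1024, device="cuda", dtype=torch.bfloat16))
# Set externally, e.g. by Megatron-LM's distributed optimizer wrapper.
weight.main_grad = torch.zeros(weight.shape, device="cuda", dtype=torch.float32)

# ... run forward/backward through the op that owns this weight ...

wgrad = weight.main_grad  # gradient is accumulated (or overwritten) here; weight.grad may be unset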
userbuffers_options, dict, optional userbuffers_options, dict, optional
Options for overlapping tensor-parallel communication with Options for overlapping tensor-parallel communication with
compute using Userbuffers. This feature is highly compute using Userbuffers. This feature is highly
...@@ -301,8 +303,8 @@ class BasicLinear(BasicOperation): ...@@ -301,8 +303,8 @@ class BasicLinear(BasicOperation):
"Tried to quantize weight with deferred initialization " "Tried to quantize weight with deferred initialization "
"due to meta device, but no quantizer was available. " "due to meta device, but no quantizer was available. "
"This is most likely because the weight was initialized " "This is most likely because the weight was initialized "
"within fp8_model_init, but the forward pass was not " "within quantized_model_init, but the forward pass was not "
"performed within fp8_autocast." "performed within autocast."
) )
quantizer.set_usage( quantizer.set_usage(
rowwise=True, rowwise=True,
...@@ -322,6 +324,20 @@ class BasicLinear(BasicOperation): ...@@ -322,6 +324,20 @@ class BasicLinear(BasicOperation):
if self.weight.device.type == "meta": if self.weight.device.type == "meta":
self.reset_parameters() self.reset_parameters()
def pre_fuser_forward(self, *, requires_grad: bool) -> None:
super().pre_fuser_forward(requires_grad=requires_grad)
if FP8GlobalStateManager.is_fp8_enabled():
# Configure quantizer usages
# Note: We cache the quantized input for backward pass,
# but discard the quantized weights.
weight_requires_grad = requires_grad and self.weight.requires_grad
input_quantizer = self.get_quantizer("forward", 0)
weight_quantizer = self.get_quantizer("forward", 1)
grad_output_quantizer = self.get_quantizer("backward", 0)
input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
weight_quantizer.set_usage(rowwise=True, columnwise=False)
grad_output_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None: def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None:
super().reset_recipe_state(recipe=recipe) super().reset_recipe_state(recipe=recipe)
...@@ -352,6 +368,35 @@ class BasicLinear(BasicOperation): ...@@ -352,6 +368,35 @@ class BasicLinear(BasicOperation):
and not getattr(self, "_with_quantized_weight", False) and not getattr(self, "_with_quantized_weight", False)
) )
# Recipe-specific configuration
# Note: This function may be called in base class constructor,
# before any basic linear attrs have been set.
if recipe is not None:
if recipe.float8_current_scaling():
input_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
input_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
weight_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_weight.power_2_scale
weight_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_weight.amax_epsilon
grad_output_quantizer.force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale
grad_output_quantizer.amax_epsilon_scales = recipe.fp8_quant_bwd_grad.amax_epsilon
if getattr(self, "sequence_parallel", False):
tensor_parallel_mode = getattr(self, "tensor_parallel_mode", None)
if tensor_parallel_mode == "column":
input_quantizer.with_amax_reduction = True
input_quantizer.amax_reduction_group = self.tensor_parallel_group
elif tensor_parallel_mode == "row":
grad_output_quantizer.with_amax_reduction = True
grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group
if recipe.nvfp4():
if getattr(self, "sequence_parallel", False):
tensor_parallel_mode = getattr(self, "tensor_parallel_mode", None)
if tensor_parallel_mode == "column":
input_quantizer.with_amax_reduction = True
input_quantizer.amax_reduction_group = self.tensor_parallel_group
elif tensor_parallel_mode == "row":
grad_output_quantizer.with_amax_reduction = True
grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group
@staticmethod @staticmethod
def _functional_forward( def _functional_forward(
input: torch.Tensor, # pylint: disable=redefined-builtin input: torch.Tensor, # pylint: disable=redefined-builtin
...@@ -568,7 +613,7 @@ class BasicLinear(BasicOperation): ...@@ -568,7 +613,7 @@ class BasicLinear(BasicOperation):
# Prepare input tensor for backward pass # Prepare input tensor for backward pass
if weight_requires_grad: if weight_requires_grad:
if with_quantized_compute and is_quantized_tensor(x_local): if with_quantized_compute and is_quantized_tensor(x_local):
if not (isinstance(x_local, Float8TensorBase) and with_x_all_gather): if not (isinstance(x_local, Float8TensorStorage) and with_x_all_gather):
# FP8 does not support all-gather of transpose data # FP8 does not support all-gather of transpose data
x_local.update_usage(rowwise_usage=False, columnwise_usage=True) x_local.update_usage(rowwise_usage=False, columnwise_usage=True)
else: else:
...@@ -731,7 +776,7 @@ class BasicLinear(BasicOperation): ...@@ -731,7 +776,7 @@ class BasicLinear(BasicOperation):
if with_quantized_compute: if with_quantized_compute:
if input_quantizer is None: if input_quantizer is None:
raise ValueError("Missing quantizer for input tensor") raise ValueError("Missing quantizer for input tensor")
input_quantizer.set_usage(columnwise=True) input_quantizer.set_usage(rowwise=False, columnwise=True)
if with_x_all_gather: if with_x_all_gather:
x, x_async = gather_along_first_dim( x, x_async = gather_along_first_dim(
x_local, x_local,
...@@ -912,34 +957,13 @@ class BasicLinear(BasicOperation): ...@@ -912,34 +957,13 @@ class BasicLinear(BasicOperation):
input_requires_grad = ctx.requires_grad input_requires_grad = ctx.requires_grad
weight_requires_grad = ctx.requires_grad and self.weight.requires_grad weight_requires_grad = ctx.requires_grad and self.weight.requires_grad
# FP8 metadata # Quantizers
input_quantizer = self.get_quantizer("forward", 0) input_quantizer = self.get_quantizer("forward", 0)
weight_quantizer = self.get_quantizer("forward", 1) weight_quantizer = self.get_quantizer("forward", 1)
output_quantizer = next_op_input_quantizer output_quantizer = next_op_input_quantizer
grad_output_quantizer = self.get_quantizer("backward", 0) grad_output_quantizer = self.get_quantizer("backward", 0)
grad_input_quantizer = prev_op_grad_output_quantizer grad_input_quantizer = prev_op_grad_output_quantizer
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
if with_quantized_compute:
# Configure quantizers
# Note: We cache the quantized input for backward pass,
# but discard the quantized weights.
input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
weight_quantizer.set_usage(rowwise=True, columnwise=False)
recipe = FP8GlobalStateManager.get_fp8_recipe()
if recipe.float8_current_scaling():
input_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
input_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
weight_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
weight_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
grad_output_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
grad_output_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
if self.sequence_parallel and self.tensor_parallel_mode == "column":
input_quantizer.with_amax_reduction = True
input_quantizer.amax_reduction_group = self.tensor_parallel_group
if self.sequence_parallel and self.tensor_parallel_mode == "row":
grad_output_quantizer.with_amax_reduction = True
grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group
# Get autocast dtype if needed # Get autocast dtype if needed
if torch.is_autocast_enabled(): if torch.is_autocast_enabled():
...@@ -997,6 +1021,7 @@ class BasicLinear(BasicOperation): ...@@ -997,6 +1021,7 @@ class BasicLinear(BasicOperation):
weight_param = self.weight weight_param = self.weight
if hasattr(weight_param, "__fsdp_param__"): if hasattr(weight_param, "__fsdp_param__"):
weight_param.main_grad = weight_param.get_main_grad() weight_param.main_grad = weight_param.get_main_grad()
accumulate_into_main_grad = not getattr(weight_param, "overwrite_main_grad", False)
if not hasattr(weight_param, "main_grad"): if not hasattr(weight_param, "main_grad"):
raise RuntimeError( raise RuntimeError(
"BasicLinear op is configured with " "BasicLinear op is configured with "
......
...@@ -11,7 +11,7 @@ import torch ...@@ -11,7 +11,7 @@ import torch
import transformer_engine_torch as tex import transformer_engine_torch as tex
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...tensor import Quantizer from ...tensor import Quantizer
from ...tensor._internal.float8_tensor_base import Float8TensorBase from ...tensor.storage.float8_tensor_storage import Float8TensorStorage
from .._common import maybe_autocast_dtype, maybe_dequantize from .._common import maybe_autocast_dtype, maybe_dequantize
from ..op import BasicOperation, OperationContext from ..op import BasicOperation, OperationContext
...@@ -56,7 +56,7 @@ class Dropout(BasicOperation): ...@@ -56,7 +56,7 @@ class Dropout(BasicOperation):
out = input_ out = input_
elif impl == "fused": elif impl == "fused":
x = input_ x = input_
if not isinstance(x, Float8TensorBase): if not isinstance(x, Float8TensorStorage):
x = maybe_dequantize(x, dtype=dtype) x = maybe_dequantize(x, dtype=dtype)
out, mask = tex.dropout_fwd(x, self.dropout_probability) out, mask = tex.dropout_fwd(x, self.dropout_probability)
elif impl == "unfused": elif impl == "unfused":
......
...@@ -9,7 +9,7 @@ from typing import Optional ...@@ -9,7 +9,7 @@ from typing import Optional
import torch import torch
from ...fp8 import FP8GlobalStateManager from ...quantization import FP8GlobalStateManager
from .._common import is_quantized_tensor from .._common import is_quantized_tensor
from ..op import BasicOperation, OperationContext from ..op import BasicOperation, OperationContext
from ...tensor import Quantizer from ...tensor import Quantizer
...@@ -18,8 +18,8 @@ from ...tensor import Quantizer ...@@ -18,8 +18,8 @@ from ...tensor import Quantizer
class Quantize(BasicOperation): class Quantize(BasicOperation):
"""Quantize tensor data """Quantize tensor data
Uses FP8 recipe from `fp8_autocast` context. When called outside Uses recipe from `autocast` context. When called outside
of an `fp8_autocast` context, this is an identity operation. of an `autocast` context, this is an identity operation.
Parameters Parameters
---------- ----------
......
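A hedged usage sketch of the `Quantize` op under the renamed `autocast` entry point shown in this diff; the exported names, recipe choice, and tensor shapes are illustrative assumptions:
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Float8CurrentScaling

quantize_op = te.ops.Quantize()
x = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)

with te.autocast(recipe=Float8CurrentScaling()):
    x_q = quantize_op(x)   # quantized according to the active recipe

x_same = quantize_op(x)    # outside the autocast context: identity operation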
...@@ -10,7 +10,7 @@ from typing import Optional ...@@ -10,7 +10,7 @@ from typing import Optional
import torch import torch
import transformer_engine_torch as tex import transformer_engine_torch as tex
from transformer_engine.pytorch.fp8 import Recipe from transformer_engine.pytorch.quantization import Recipe
from transformer_engine.pytorch.ops.basic import Bias from transformer_engine.pytorch.ops.basic import Bias
from transformer_engine.pytorch.ops.basic.activation import ( from transformer_engine.pytorch.ops.basic.activation import (
_ActivationOperation, _ActivationOperation,
......
...@@ -59,6 +59,7 @@ class BackwardLinearAdd(FusedOperation): ...@@ -59,6 +59,7 @@ class BackwardLinearAdd(FusedOperation):
weight_param = linear_op.weight weight_param = linear_op.weight
if hasattr(weight_param, "__fsdp_param__"): if hasattr(weight_param, "__fsdp_param__"):
weight_param.main_grad = weight_param.get_main_grad() weight_param.main_grad = weight_param.get_main_grad()
accumulate_into_main_grad = not getattr(weight_param, "overwrite_main_grad", False)
if not hasattr(weight_param, "main_grad"): if not hasattr(weight_param, "main_grad"):
raise RuntimeError( raise RuntimeError(
"BasicLinear op is configured with " "BasicLinear op is configured with "
......
...@@ -60,6 +60,7 @@ class BackwardLinearScale(FusedOperation): ...@@ -60,6 +60,7 @@ class BackwardLinearScale(FusedOperation):
weight_param = linear_op.weight weight_param = linear_op.weight
if hasattr(weight_param, "__fsdp_param__"): if hasattr(weight_param, "__fsdp_param__"):
weight_param.main_grad = weight_param.get_main_grad() weight_param.main_grad = weight_param.get_main_grad()
accumulate_into_main_grad = not getattr(weight_param, "overwrite_main_grad", False)
if not hasattr(weight_param, "main_grad"): if not hasattr(weight_param, "main_grad"):
raise RuntimeError( raise RuntimeError(
"BasicLinear op is configured with " "BasicLinear op is configured with "
......
...@@ -11,7 +11,7 @@ from typing import Any, Optional ...@@ -11,7 +11,7 @@ from typing import Any, Optional
import torch import torch
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...fp8 import FP8GlobalStateManager from ...quantization import FP8GlobalStateManager
from ...tensor import Quantizer from ...tensor import Quantizer
from ..basic import BasicLinear, Bias from ..basic import BasicLinear, Bias
from ..op import FusedOperation, FusibleOperation, OperationContext from ..op import FusedOperation, FusibleOperation, OperationContext
...@@ -85,7 +85,7 @@ class ForwardLinearBiasActivation(FusedOperation): ...@@ -85,7 +85,7 @@ class ForwardLinearBiasActivation(FusedOperation):
input_requires_grad = linear_op_ctx.requires_grad input_requires_grad = linear_op_ctx.requires_grad
weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad
# FP8 metadata # Quantizers
input_quantizer = linear_op.get_quantizer("forward", 0) input_quantizer = linear_op.get_quantizer("forward", 0)
weight_quantizer = linear_op.get_quantizer("forward", 1) weight_quantizer = linear_op.get_quantizer("forward", 1)
output_quantizer = next_op_input_quantizer output_quantizer = next_op_input_quantizer
......
...@@ -11,7 +11,7 @@ from typing import Any, Optional ...@@ -11,7 +11,7 @@ from typing import Any, Optional
import torch import torch
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...fp8 import FP8GlobalStateManager from ...quantization import FP8GlobalStateManager
from ...tensor import Quantizer from ...tensor import Quantizer
from ..basic import AddExtraInput, BasicLinear, Bias from ..basic import AddExtraInput, BasicLinear, Bias
from ..op import FusedOperation, FusibleOperation, OperationContext from ..op import FusedOperation, FusibleOperation, OperationContext
...@@ -79,7 +79,7 @@ class ForwardLinearBiasAdd(FusedOperation): ...@@ -79,7 +79,7 @@ class ForwardLinearBiasAdd(FusedOperation):
input_requires_grad = linear_op_ctx.requires_grad input_requires_grad = linear_op_ctx.requires_grad
weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad
# FP8 metadata # Quantizers
input_quantizer = linear_op.get_quantizer("forward", 0) input_quantizer = linear_op.get_quantizer("forward", 0)
weight_quantizer = linear_op.get_quantizer("forward", 1) weight_quantizer = linear_op.get_quantizer("forward", 1)
output_quantizer = None output_quantizer = None
......
...@@ -11,7 +11,7 @@ from typing import Any, Optional ...@@ -11,7 +11,7 @@ from typing import Any, Optional
import torch import torch
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...fp8 import FP8GlobalStateManager from ...quantization import FP8GlobalStateManager
from ...tensor import Quantizer from ...tensor import Quantizer
from ..basic import AddExtraInput, BasicLinear, ConstantScale from ..basic import AddExtraInput, BasicLinear, ConstantScale
from ..op import ( from ..op import (
...@@ -58,7 +58,7 @@ class ForwardLinearScaleAdd(FusedOperation): ...@@ -58,7 +58,7 @@ class ForwardLinearScaleAdd(FusedOperation):
input_requires_grad = linear_op_ctx.requires_grad input_requires_grad = linear_op_ctx.requires_grad
weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad
# FP8 metadata # Quantizers
input_quantizer = linear_op.get_quantizer("forward", 0) input_quantizer = linear_op.get_quantizer("forward", 0)
weight_quantizer = linear_op.get_quantizer("forward", 1) weight_quantizer = linear_op.get_quantizer("forward", 1)
output_quantizer = None output_quantizer = None
......
...@@ -523,6 +523,7 @@ class UserbuffersBackwardLinear(FusedOperation): ...@@ -523,6 +523,7 @@ class UserbuffersBackwardLinear(FusedOperation):
weight_param = linear_op.weight weight_param = linear_op.weight
if hasattr(weight_param, "__fsdp_param__"): if hasattr(weight_param, "__fsdp_param__"):
weight_param.main_grad = weight_param.get_main_grad() weight_param.main_grad = weight_param.get_main_grad()
accumulate_into_main_grad = not getattr(weight_param, "overwrite_main_grad", False)
if not hasattr(weight_param, "main_grad"): if not hasattr(weight_param, "main_grad"):
raise RuntimeError( raise RuntimeError(
"BasicLinear op is configured with " "BasicLinear op is configured with "
......
...@@ -14,7 +14,7 @@ from transformer_engine_torch import CommOverlapType ...@@ -14,7 +14,7 @@ from transformer_engine_torch import CommOverlapType
from ...cpp_extensions import general_gemm from ...cpp_extensions import general_gemm
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...distributed import get_distributed_world_size from ...distributed import get_distributed_world_size
from ...fp8 import FP8GlobalStateManager from ...quantization import FP8GlobalStateManager
from ...module.base import ( from ...module.base import (
fill_userbuffers_buffer_for_all_gather, fill_userbuffers_buffer_for_all_gather,
get_ub, get_ub,
...@@ -23,7 +23,7 @@ from ...module.base import ( ...@@ -23,7 +23,7 @@ from ...module.base import (
) )
from ...tensor.quantized_tensor import Quantizer from ...tensor.quantized_tensor import Quantizer
from ...tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer from ...tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
from ...tensor._internal.float8_tensor_base import Float8TensorBase from ...tensor.storage.float8_tensor_storage import Float8TensorStorage
from .._common import maybe_dequantize, is_quantized_tensor from .._common import maybe_dequantize, is_quantized_tensor
from ..basic import BasicLinear, Bias, ReduceScatter from ..basic import BasicLinear, Bias, ReduceScatter
from ..op import ( from ..op import (
...@@ -267,7 +267,7 @@ class UserbuffersForwardLinear(FusedOperation): ...@@ -267,7 +267,7 @@ class UserbuffersForwardLinear(FusedOperation):
# Prepare input tensor for backward pass # Prepare input tensor for backward pass
if weight_requires_grad: if weight_requires_grad:
if with_quantized_compute and is_quantized_tensor(x_local): if with_quantized_compute and is_quantized_tensor(x_local):
if not (isinstance(x_local, Float8TensorBase) and with_ub_all_gather): if not (isinstance(x_local, Float8TensorStorage) and with_ub_all_gather):
# FP8 does not support all-gather of transpose data # FP8 does not support all-gather of transpose data
x_local.update_usage(rowwise_usage=False, columnwise_usage=True) x_local.update_usage(rowwise_usage=False, columnwise_usage=True)
else: else:
......
...@@ -11,7 +11,7 @@ import itertools ...@@ -11,7 +11,7 @@ import itertools
import torch import torch
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager, Recipe, DelayedScaling from transformer_engine.pytorch.quantization import FP8GlobalStateManager, Recipe, DelayedScaling
from transformer_engine.pytorch.ops.op import ( from transformer_engine.pytorch.ops.op import (
BasicOperation, BasicOperation,
FusibleOperation, FusibleOperation,
...@@ -472,6 +472,10 @@ class OperationFuser: ...@@ -472,6 +472,10 @@ class OperationFuser:
# Attempt to fuse operations if necessary # Attempt to fuse operations if necessary
self.maybe_fuse_ops(is_grad_enabled, recipe, input, basic_op_extra_inputs) self.maybe_fuse_ops(is_grad_enabled, recipe, input, basic_op_extra_inputs)
# Initialization before forward
for idx, op in enumerate(self._basic_ops):
op.pre_fuser_forward(requires_grad=idx >= self.first_op_requiring_backward)
# Fuser forward pass # Fuser forward pass
if is_grad_enabled: if is_grad_enabled:
forward_func = _OperationFuserAutogradFunction.apply forward_func = _OperationFuserAutogradFunction.apply
......
...@@ -14,10 +14,10 @@ from typing import Any, Optional ...@@ -14,10 +14,10 @@ from typing import Any, Optional
import torch import torch
from transformer_engine.common.recipe import Recipe from transformer_engine.common.recipe import Recipe
from ..fp8 import ( from ..quantization import (
FP8GlobalStateManager, FP8GlobalStateManager,
RecipeState, RecipeState,
fp8_autocast, autocast,
) )
from ..tensor import Quantizer from ..tensor import Quantizer
...@@ -65,6 +65,13 @@ class FusibleOperation(torch.nn.Module, metaclass=abc.ABCMeta): ...@@ -65,6 +65,13 @@ class FusibleOperation(torch.nn.Module, metaclass=abc.ABCMeta):
def pre_first_fuser_forward(self) -> None: def pre_first_fuser_forward(self) -> None:
"""Preprocessing before first fuser forward pass""" """Preprocessing before first fuser forward pass"""
def pre_fuser_forward(
self,
*,
requires_grad: bool, # pylint: disable=unused-argument
) -> None:
"""Preprocessing before fuser forward pass"""
def get_input_quantizer(self) -> Optional[Quantizer]: def get_input_quantizer(self) -> Optional[Quantizer]:
"""Get builder class for quantized input tensor""" """Get builder class for quantized input tensor"""
...@@ -588,6 +595,9 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta): ...@@ -588,6 +595,9 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
extra[key] = val extra[key] = val
state[mode]["extra_fp8_variables"] = extra state[mode]["extra_fp8_variables"] = extra
if not state:
return torch.empty(0, dtype=torch.uint8)
# Serialize state into byte tensor # Serialize state into byte tensor
torch.cuda.synchronize() torch.cuda.synchronize()
state_serialized = bytearray(pickle.dumps(state)) state_serialized = bytearray(pickle.dumps(state))
...@@ -624,7 +634,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta): ...@@ -624,7 +634,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
# Get op's quantizer state, initializing if needed # Get op's quantizer state, initializing if needed
if self._fp8_metas is None or self._fp8_metas[mode] is None: if self._fp8_metas is None or self._fp8_metas[mode] is None:
with fp8_autocast(fp8_recipe=state[mode]["recipe"]): with autocast(recipe=state[mode]["recipe"]):
self.reset_recipe_state(recipe=state[mode]["recipe"]) self.reset_recipe_state(recipe=state[mode]["recipe"])
fp8_meta = self._fp8_metas[mode] fp8_meta = self._fp8_metas[mode]
...@@ -710,6 +720,10 @@ class FusedOperation(FusibleOperation): ...@@ -710,6 +720,10 @@ class FusedOperation(FusibleOperation):
for op in self.basic_ops: for op in self.basic_ops:
op.pre_first_fuser_forward() op.pre_first_fuser_forward()
def pre_fuser_forward(self, *, requires_grad: bool) -> None:
for op in self.basic_ops:
op.pre_fuser_forward(requires_grad=requires_grad)
def forward( def forward(
self, self,
input: torch.Tensor, # pylint: disable=redefined-builtin input: torch.Tensor, # pylint: disable=redefined-builtin
......