Commit 063ef88d authored by wenjh

Merge nv main up to v2.10.0.dev0


Signed-off-by: wenjh <wenjh@sugon.com>
parents 91670b05 5624dbb4
......@@ -22,7 +22,7 @@ from .base import (
_2X_ACC_WGRAD,
)
from ._common import WeightGradStore
from ..fp8 import FP8GlobalStateManager
from ..quantization import FP8GlobalStateManager
from ..utils import (
divide,
cast_if_needed,
......@@ -46,7 +46,7 @@ from ..cpu_offload import is_cpu_offload_enabled
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
from ..tensor.quantized_tensor import (
QuantizedTensorBase,
QuantizedTensorStorage,
Quantizer,
prepare_for_saving,
restore_from_saved,
......@@ -204,39 +204,27 @@ class _GroupedLinear(torch.autograd.Function):
inputmats[0] = inp
else:
for inputmat in inputmats:
if isinstance(inputmat, QuantizedTensorBase):
if isinstance(inputmat, QuantizedTensorStorage):
inputmat.update_usage(rowwise_usage=False, columnwise_usage=True)
else:
inputmats = [None] * num_gemms
if inp.requires_grad:
for weight in weights_fp8:
if isinstance(weight, QuantizedTensorBase):
if isinstance(weight, QuantizedTensorStorage):
weight.update_usage(columnwise_usage=True)
for i in range(num_gemms):
weights[i].offloading_activation = False
weights_fp8[i].offloading_activation = False
biases[i].offloading_activation = False
ctx.fine_grained_activation_offloading = fine_grained_activation_offloading
if fine_grained_activation_offloading and cpu_offloading:
raise ValueError(
f"Do not use fine_grained_activation_offloading and cpu_offloading at the same time."
)
if cpu_offloading:
ctx.grad_added_to_main_grad = hasattr(weights[0], "grad_added_to_main_grad")
if (
fine_grained_activation_offloading
and weights[0].requires_grad
and fuse_wgrad_accumulation
):
grad_added_to_main_grad_list = []
if ctx.grad_added_to_main_grad:
# If you are passing torch.nn.Parameter through the Torch hooks, you will
# get back torch.Tensor. Torch rips off the Parameter wrapper.
# You need to preserve the weight object to have all the attributes user
# sets for the weights. Because of this, it is not recommended to offload
# weights if weights are externally touched outside this module
ctx.weight_objects = []
for weight in weights:
if weight.requires_grad and hasattr(weight, "grad_added_to_main_grad"):
grad_added_to_main_grad_list.append(weight.grad_added_to_main_grad)
weight.grad_added_to_main_grad = True
else:
grad_added_to_main_grad_list.append(None)
ctx.grad_added_to_main_grad_list = grad_added_to_main_grad_list
ctx.weight_objects.append(weight)
tensors_to_save, tensor_objects = prepare_for_saving(
*inputmats,
......@@ -300,13 +288,15 @@ class _GroupedLinear(torch.autograd.Function):
biases = saved_tensors[3 * N : 4 * N]
main_grads = [main_grad_func() for main_grad_func in ctx.main_grad_funcs]
if (ctx.cpu_offloading or ctx.fine_grained_activation_offloading) and ctx.fuse_wgrad_accumulation:
for i in range(ctx.num_gemms):
w = torch.nn.Parameter(weights[i], weights[i].requires_grad)
w.main_grad = main_grads[i]
weights[i] = w
if ctx.fine_grained_activation_offloading and weights[0].requires_grad:
weights[i].grad_added_to_main_grad = ctx.grad_added_to_main_grad_list[i]
if ctx.cpu_offloading:
if ctx.grad_added_to_main_grad:
for i, weight in enumerate(ctx.weight_objects):
origin_weights[i] = ctx.weight_objects[i]
ctx.weight_objects[i] = None
if ctx.fuse_wgrad_accumulation:
for i in range(N):
origin_weights[i].main_grad = main_grads[i]
# Preprocess grad output
grad_output_view = grad_output.contiguous().view(-1, grad_output.shape[-1])
......@@ -369,7 +359,7 @@ class _GroupedLinear(torch.autograd.Function):
)
for weight, quantizer in zip(weights, ctx.weight_quantizers):
if quantizer is not None and isinstance(weight, QuantizedTensorBase):
if quantizer is not None and isinstance(weight, QuantizedTensorStorage):
weight.update_usage(
rowwise_usage=quantizer.rowwise_usage,
columnwise_usage=quantizer.columnwise_usage,
......@@ -433,7 +423,11 @@ class _GroupedLinear(torch.autograd.Function):
use_bias=ctx.use_bias if grad_biases[0] is None else None,
bias=biases,
use_split_accumulator=wgrad_gemm_use_split_accumulator,
accumulate=accumulate_wgrad_into_param_main_grad,
accumulate=(
accumulate_wgrad_into_param_main_grad
if not getattr(weights[0], "overwrite_main_grad", False)
else False
),
)
# WGRAD
if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
......@@ -555,7 +549,9 @@ class GroupedLinear(TransformerEngineBaseModule):
the weight gradient. When enabled, it is assumed that the weights
have an additional `main_grad` attribute (used instead of the
regular `grad`) which is a pre-allocated buffer of the correct
size to accumulate gradients in.
size to accumulate gradients in. If this argument is enabled and the
weight tensor has the attribute `overwrite_main_grad` set to `True`,
`main_grad` is overwritten instead of accumulated into.
return_bias : bool, default = `False`
when set to `True`, this module will not apply the additive bias itself, but
instead return the bias value during the forward pass together with the
......@@ -772,7 +768,7 @@ class GroupedLinear(TransformerEngineBaseModule):
produced)
"""
assert not isinstance(
inp, QuantizedTensorBase
inp, QuantizedTensorStorage
), "GroupedLinear doesn't support input tensor in FP8."
assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs."
......@@ -907,16 +903,17 @@ class GroupedLinear(TransformerEngineBaseModule):
self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["bwd"]
].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon
def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]:
def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage]]:
"""Get the weight tensors of the module."""
weight_tensors = [getattr(self, f"weight{i}") for i in range(self.num_gemms)]
if not self.fp8 and any(isinstance(w, QuantizedTensorBase) for w in weight_tensors):
if not self.fp8 and any(isinstance(w, QuantizedTensorStorage) for w in weight_tensors):
warnings.warn(
"You are using quantized weights without quantized compute. "
"Please make sure this is intentional."
)
weight_tensors = [
w.dequantize() if isinstance(w, QuantizedTensorBase) else w for w in weight_tensors
w.dequantize() if isinstance(w, QuantizedTensorStorage) else w
for w in weight_tensors
]
return weight_tensors
......
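The `fuse_wgrad_accumulation` / `overwrite_main_grad` behavior touched by the grouped_linear.py changes above can be exercised roughly as follows. This is a minimal sketch, not part of the diff: the Megatron-style `main_grad` wiring and the shapes are illustrative, and the `GroupedLinear(num_gemms, in_features, out_features)` signature and `m_splits` forward argument are assumed from the rest of Transformer Engine.

import torch
import transformer_engine.pytorch as te

# Sketch: fused wgrad accumulation with the new overwrite_main_grad flag.
# Each weight carries a pre-allocated float32 main_grad buffer; the wgrad
# GEMM normally accumulates into it, but overwrites it (accumulate=False)
# when overwrite_main_grad is True.
layer = te.GroupedLinear(
    num_gemms=2, in_features=1024, out_features=1024,
    bias=False, fuse_wgrad_accumulation=True, device="cuda",
)
for i in range(layer.num_gemms):
    w = getattr(layer, f"weight{i}")
    w.main_grad = torch.zeros_like(w, dtype=torch.float32)
    w.overwrite_main_grad = True  # flag handled by this merge

inp = torch.randn(4096, 1024, device="cuda", requires_grad=True)
out = layer(inp, m_splits=[2048, 2048])
out.sum().backward()  # wgrad lands in weight{i}.main_grad rather than .grad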
......@@ -16,6 +16,7 @@ import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe
from transformer_engine.pytorch import torch_version
from transformer_engine.pytorch.tensor.utils import is_experimental
from .base import (
fill_userbuffers_buffer_for_all_gather,
get_workspace,
......@@ -26,9 +27,10 @@ from .base import (
_2X_ACC_DGRAD,
_2X_ACC_WGRAD,
)
from ..fp8 import FP8GlobalStateManager
from ..quantization import FP8GlobalStateManager
from ..utils import (
assert_dim_for_fp8_exec,
assert_dim_for_all_gather,
cast_if_needed,
clear_tensor_data,
divide,
......@@ -57,7 +59,7 @@ from ..graph import is_graph_capturing
from ._common import apply_normalization, noop_cat, WeightGradStore
from ..tensor.quantized_tensor import (
QuantizedTensor,
QuantizedTensorBase,
QuantizedTensorStorage,
Quantizer,
prepare_for_saving,
restore_from_saved,
......@@ -65,8 +67,8 @@ from ..tensor.quantized_tensor import (
from ...debug.pytorch.debug_state import TEDebugState
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage
from ..tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage
from ..export import is_in_onnx_export_mode, assert_warmed_up
from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
......@@ -144,6 +146,8 @@ class _LayerNormLinear(torch.autograd.Function):
if ub_name is not None:
nvtx_label = f"{nvtx_label}.{ub_name}"
with_input_all_gather = parallel_mode == "column" and sequence_parallel
# Make sure input dimensions are compatible
out_features, in_features = weight.shape
inp_shape = inp.shape
......@@ -153,6 +157,7 @@ class _LayerNormLinear(torch.autograd.Function):
inputmat = inp
if fp8:
assert_dim_for_fp8_exec(inputmat, weight)
assert_dim_for_all_gather(inputmat, with_input_all_gather, input_quantizer)
# Cast for native AMP
nvtx_range_push(f"{nvtx_label}.norm_input_cast")
......@@ -166,7 +171,6 @@ class _LayerNormLinear(torch.autograd.Function):
weight_requires_grad = weight.requires_grad
backward_needs_input = is_grad_enabled and weight_requires_grad
with_input_all_gather = parallel_mode == "column" and sequence_parallel
# Configure Userbuffers communication (comm+GEMM overlap)
if debug: # turn off userbuffers in debug mode
......@@ -199,11 +203,13 @@ class _LayerNormLinear(torch.autograd.Function):
# Avoid quantized norm kernel if norm output will be returned
# or if a gather of ln_out must be in high precision.
experimental = is_experimental(input_quantizer)
with_quantized_norm = (
fp8
and not debug
and not return_layernorm_output
and not return_layernorm_output_gathered
and not experimental # TODO(negvet): and not FP8GlobalStateManager.get_fp8_recipe().custom()
)
# Apply normalization
......@@ -249,7 +255,8 @@ class _LayerNormLinear(torch.autograd.Function):
quantizer = None
if fp8 or debug:
quantizer = input_quantizer
if not with_quantized_norm:
# experimental recipe doesn't need to support quantized AG
if not with_quantized_norm and not experimental:
ln_out = quantizer(ln_out)
quantizer.set_usage(rowwise=True, columnwise=False)
if ub_overlap_ag_fprop: # Initialize Userbuffers all-gather
......@@ -280,7 +287,7 @@ class _LayerNormLinear(torch.autograd.Function):
weightmat = weight
quantized_weight = False
if fp8 or debug:
quantized_weight = not isinstance(weight, QuantizedTensorBase)
quantized_weight = not isinstance(weight, QuantizedTensorStorage)
# Configure quantizer
if weight_quantizer is not None:
......@@ -405,18 +412,18 @@ class _LayerNormLinear(torch.autograd.Function):
# Input with column-wise usage is needed for wgrad GEMM.
if backward_needs_input:
if isinstance(ln_out, QuantizedTensorBase):
if isinstance(ln_out, QuantizedTensorStorage):
# For sequence parallel in vanilla FP8, rowwise data is
# needed to gather the input. For MXFP8, only columnwise
# data can be all-gathered.
if (
isinstance(ln_out, (MXFP8TensorBase, Float8BlockwiseQTensorBase))
isinstance(ln_out, (MXFP8TensorStorage, Float8BlockwiseQTensorStorage))
or not ctx.ln_out_needs_gather
):
ln_out.update_usage(rowwise_usage=False)
# Weight with column-wise usage is needed for dgrad GEMM.
if isinstance(weightmat, QuantizedTensorBase):
if isinstance(weightmat, QuantizedTensorStorage):
weightmat.update_usage(columnwise_usage=True)
if cpu_offloading:
......@@ -716,9 +723,9 @@ class _LayerNormLinear(torch.autograd.Function):
# --------------------------------------------------
# Make sure required data is available
if isinstance(grad_output, QuantizedTensorBase):
if isinstance(grad_output, QuantizedTensorStorage):
grad_output.update_usage(rowwise_usage=True)
if ctx.weight_quantizer is not None and isinstance(weight, QuantizedTensorBase):
if ctx.weight_quantizer is not None and isinstance(weight, QuantizedTensorStorage):
weight.update_usage(columnwise_usage=True)
# Choose whether to use GEMM kernel with split accumulator
......@@ -837,14 +844,14 @@ class _LayerNormLinear(torch.autograd.Function):
ln_out_total_work.wait()
ln_out_total_work = None
if ctx.fp8 or ctx.debug:
if isinstance(ln_out_total, QuantizedTensorBase):
if isinstance(ln_out_total, QuantizedTensorStorage):
ln_out_total.update_usage(columnwise_usage=True)
else:
ctx.input_quantizer.set_usage(rowwise=False, columnwise=True)
ln_out_total = ctx.input_quantizer(ln_out_total)
if ctx.fp8 or ctx.debug:
if isinstance(grad_output, QuantizedTensorBase):
if isinstance(grad_output, QuantizedTensorStorage):
grad_output.update_usage(columnwise_usage=True)
else:
ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
......@@ -880,7 +887,11 @@ class _LayerNormLinear(torch.autograd.Function):
main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype
),
"quantization_params": ctx.grad_weight_quantizer,
"accumulate": accumulate_wgrad_into_param_main_grad,
"accumulate": (
accumulate_wgrad_into_param_main_grad
if not getattr(weight, "overwrite_main_grad", False)
else False
),
"layout": "NT",
"out": main_grad if ctx.fuse_wgrad_accumulation else None,
"bias": (bias if (grad_bias is None and not ctx.fp8) else None),
......@@ -1037,7 +1048,7 @@ class _LayerNormLinear(torch.autograd.Function):
nvtx_range_pop(f"{nvtx_label}.reduce_and_update_fp8_tensors")
# Scatter fp8 weight buffers
# if ctx.fp8 and not isinstance(weight, QuantizedTensorBase):
# if ctx.fp8 and not isinstance(weight, QuantizedTensorStorage):
# _fsdp_scatter_tensors(ctx.fsdp_group, weight_fp8)
return (
......@@ -1164,7 +1175,9 @@ class LayerNormLinear(TransformerEngineBaseModule):
the weight gradient. When enabled, it is assumed that the weights
have an additional `main_grad` attribute (used instead of the
regular `grad`) which is a pre-allocated buffer of the correct
size to accumulate gradients in.
size to accumulate gradients in. If this argument is enabled and the
weight tensor has the attribute `overwrite_main_grad` set to `True`,
`main_grad` is overwritten instead of accumulated into.
return_bias : bool, default = `False`
when set to `True`, this module will not apply the additive bias itself, but
instead return the bias value during the forward pass together with the
......@@ -1470,6 +1483,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
self._customize_quantizers_float8_current_scaling(fwd, recipe)
elif recipe.float8_block_scaling():
self._customize_quantizers_float8_blockwise_scaling(fwd, recipe)
elif recipe.nvfp4():
self._customize_quantizers_nvfp4(fwd, recipe)
# elif other recipes (mxfp8, etc)
def reset_layer_norm_parameters(self) -> None:
......@@ -1812,7 +1827,29 @@ class LayerNormLinear(TransformerEngineBaseModule):
tex.FP8BwdTensors.GRAD_OUTPUT1
].amax_reduction_group = self.tp_group
def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]:
def _customize_quantizers_nvfp4(self, fwd: bool, recipe: Recipe) -> None:
"""Customize quantizers based on current scaling recipe + layernorm_linear."""
assert recipe.nvfp4(), "Incorrect recipe."
if fwd:
if self.sequence_parallel and self.parallel_mode == "column":
# set input_quantizer with amax reduction TP group
self.quantizers["scaling_fwd"][
tex.FP8FwdTensors.GEMM1_INPUT
].with_amax_reduction = True
self.quantizers["scaling_fwd"][
tex.FP8FwdTensors.GEMM1_INPUT
].amax_reduction_group = self.tp_group
else:
if self.sequence_parallel and self.parallel_mode == "row":
# customize grad_output_quantizer with amax reduction TP group
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT1
].with_amax_reduction = True
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT1
].amax_reduction_group = self.tp_group
def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage]]:
"""Get the weight tensors of the module."""
unfused_weights = [getattr(self, name) for name in self.weight_names]
if any(isinstance(w, QuantizedTensor) for w in unfused_weights):
......
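The `_customize_quantizers_nvfp4` method added above enables amax reduction over the tensor-parallel group when sequence parallelism is active. The sketch below illustrates the underlying idea with plain torch.distributed; it is not Transformer Engine's implementation, and it assumes an already-initialized tensor-parallel process group. Because each rank only holds a sequence shard, the quantization scale has to be derived from the group-wide amax so that all ranks quantize consistently.

import torch
import torch.distributed as dist

def group_amax(local_shard: torch.Tensor, tp_group: dist.ProcessGroup) -> torch.Tensor:
    # Local amax of this rank's sequence shard.
    amax = local_shard.detach().abs().amax().float()
    # Reduce with MAX so every rank ends up with the same amax,
    # and therefore computes the same quantization scale.
    dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=tp_group)
    return amax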
......@@ -18,6 +18,7 @@ import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe
from transformer_engine.pytorch import torch_version
from transformer_engine.pytorch.tensor.utils import is_experimental
from .base import (
fill_userbuffers_buffer_for_all_gather,
get_workspace,
......@@ -28,7 +29,7 @@ from .base import (
_2X_ACC_DGRAD,
_2X_ACC_WGRAD,
)
from ..fp8 import FP8GlobalStateManager
from ..quantization import FP8GlobalStateManager
from ..jit import (
bias_gelu_fused,
bgrad_dgelu_fused,
......@@ -41,6 +42,7 @@ from ..utils import (
init_method_constant,
cast_if_needed,
assert_dim_for_fp8_exec,
assert_dim_for_all_gather,
clear_tensor_data,
requires_grad,
needs_quantized_gemm,
......@@ -65,11 +67,12 @@ from ..tensor.float8_tensor import (
Float8Tensor,
)
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor.nvfp4_tensor import NVFP4Quantizer
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ._common import apply_normalization, WeightGradStore
from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ..tensor.quantized_tensor import (
QuantizedTensorBase,
QuantizedTensorStorage,
Quantizer,
prepare_for_saving,
restore_from_saved,
......@@ -120,8 +123,14 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None):
"swiglu": (tex.swiglu, tex.dswiglu, None),
}
# no activation fusion written yet
# Per-tensor current scaling or fp8 blockwise scaling: []
if recipe.float8_current_scaling() or recipe.float8_block_scaling():
# Per-tensor current scaling or fp8 blockwise scaling or custom quantization: []
# TODO(ksivaman): Fuse nvfp4 act once kernel is available.
if (
recipe.float8_current_scaling()
or recipe.float8_block_scaling()
or recipe.nvfp4()
or recipe.custom()
):
return {
"gelu": (tex.gelu, tex.dgelu, None),
"geglu": (tex.geglu, tex.dgeglu, None),
......@@ -218,6 +227,7 @@ class _LayerNormMLP(torch.autograd.Function):
inputmat = inp.view((-1, in_features))
if fp8:
assert_dim_for_fp8_exec(inputmat, fc1_weight, fc2_weight)
assert_dim_for_all_gather(inputmat, sequence_parallel, fc1_input_quantizer)
activation_func = _act_func(
activation, FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
......@@ -265,11 +275,13 @@ class _LayerNormMLP(torch.autograd.Function):
# high precision layernorm output and output of the linear are returned
# for debug: layernorm output = high precision, to enable processing of this norm
experimental = is_experimental(fc1_input_quantizer)
with_quantized_norm = (
fp8
and not debug
and not return_layernorm_output
and not return_layernorm_output_gathered
and not experimental
)
# Apply normalization
......@@ -309,7 +321,8 @@ class _LayerNormMLP(torch.autograd.Function):
quantizer = None
if fp8 or debug:
quantizer = fc1_input_quantizer
if not with_quantized_norm:
# experimental recipe doesn't need to support quantized AG
if not with_quantized_norm and not experimental:
ln_out = fc1_input_quantizer(ln_out)
fc1_input_quantizer.set_usage(rowwise=True, columnwise=False)
if ub_overlap_ag:
......@@ -447,10 +460,18 @@ class _LayerNormMLP(torch.autograd.Function):
act_out = fc2_input_quantizer(act_out)
else:
fc1_out, *_ = fc1_outputs
if fp8 and FP8GlobalStateManager.get_fp8_recipe().float8_block_scaling():
# tex.quantize does not support GELU fusion for blockwise.
if fp8:
recipe = FP8GlobalStateManager.get_fp8_recipe()
if recipe.float8_block_scaling():
# tex.quantize does not support GELU fusion for blockwise
act_out = activation_func(fc1_out, None)
act_out = tex.quantize(act_out, fc2_input_quantizer)
elif recipe.custom():
# tex.quantize does not support custom quantizers
act_out = activation_func(fc1_out, None)
act_out = fc2_input_quantizer(act_out)
else:
act_out = activation_func(fc1_out, fc2_input_quantizer)
else:
if fp8_calibration:
act_out = activation_func(fc1_out, None)
......@@ -521,9 +542,9 @@ class _LayerNormMLP(torch.autograd.Function):
if is_grad_enabled:
# Weight with column-wise usage is needed for dgrad GEMM.
if isinstance(fc1_weight_final, QuantizedTensorBase):
if isinstance(fc1_weight_final, QuantizedTensorStorage):
fc1_weight_final.update_usage(columnwise_usage=True)
if isinstance(fc2_weight_final, QuantizedTensorBase):
if isinstance(fc2_weight_final, QuantizedTensorStorage):
fc2_weight_final.update_usage(columnwise_usage=True)
if cpu_offloading:
......@@ -555,6 +576,7 @@ class _LayerNormMLP(torch.autograd.Function):
if not fc2_weight.requires_grad:
clear_tensor_data(act_out)
act_out = None
tensors_to_save, tensor_objects = prepare_for_saving(
inputmat,
ln_weight,
......@@ -680,6 +702,7 @@ class _LayerNormMLP(torch.autograd.Function):
mu,
rsigma,
) = restore_from_saved(ctx.tensor_objects, saved_tensors)
# Delete the references to tensor objects once they've been consumed
# by the `restore_from_saved` method to construct back the actual tensors.
ctx.tensor_objects = None
......@@ -820,10 +843,10 @@ class _LayerNormMLP(torch.autograd.Function):
)
# Make sure required data is available
if isinstance(grad_output, QuantizedTensorBase):
if isinstance(grad_output, QuantizedTensorStorage):
grad_output.update_usage(rowwise_usage=True)
if ctx.fc2_weight_quantizer is not None and isinstance(
ctx.fc2_weight, QuantizedTensorBase
ctx.fc2_weight, QuantizedTensorStorage
):
ctx.fc2_weight.update_usage(columnwise_usage=True)
......@@ -905,14 +928,14 @@ class _LayerNormMLP(torch.autograd.Function):
# Note: Synchronize tensor-parallel communication and
# make sure required data is available
if ctx.fp8 or ctx.debug:
if isinstance(act_out, QuantizedTensorBase):
if isinstance(act_out, QuantizedTensorStorage):
act_out.update_usage(columnwise_usage=True)
else:
ctx.fc2_input_quantizer.set_usage(rowwise=False, columnwise=True)
act_out = ctx.fc2_input_quantizer(act_out)
if ctx.fp8 or ctx.debug:
if isinstance(grad_output, QuantizedTensorBase):
if isinstance(grad_output, QuantizedTensorStorage):
grad_output.update_usage(columnwise_usage=True)
else:
ctx.fc2_grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
......@@ -932,7 +955,11 @@ class _LayerNormMLP(torch.autograd.Function):
else ctx.activation_dtype
),
"quantization_params": ctx.fc2_grad_weight_quantizer, # wgrad in high precision
"accumulate": accumulate_wgrad_into_param_main_grad,
"accumulate": (
accumulate_wgrad_into_param_main_grad
if not getattr(fc1_weight, "overwrite_main_grad", False)
else False
),
"layout": "NT",
"out": origin_fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
"bias": fc2_bias if fc2_bias is not None and fc2_bias_grad is None else None,
......@@ -1028,8 +1055,11 @@ class _LayerNormMLP(torch.autograd.Function):
) # activation in high precision
if ctx.fp8:
# TODO float8 blockwise current scaling has no bgrad fusion for now
if isinstance(ctx.fc1_grad_output_quantizer, Float8BlockQuantizer):
# TODO float8 blockwise current scaling (as well as custom quantizers) has no bgrad fusion for now
if (
isinstance(ctx.fc1_grad_output_quantizer, Float8BlockQuantizer)
or ctx.fp8_recipe.custom()
):
fc1_bias_grad = dact.view(-1, dact.shape[-1]).sum(dim=0)
dact = ctx.fc1_grad_output_quantizer(dact)
else:
......@@ -1074,7 +1104,7 @@ class _LayerNormMLP(torch.autograd.Function):
# Make sure required data is available
if ctx.fc1_weight_quantizer is not None and isinstance(
ctx.fc1_weight_quantizer, QuantizedTensorBase
ctx.fc1_weight_quantizer, QuantizedTensorStorage
):
ctx.fc1_weight.update_usage(columnwise_usage=True)
......@@ -1145,7 +1175,7 @@ class _LayerNormMLP(torch.autograd.Function):
ln_out_total_work.wait()
ln_out_total_work = None
if ctx.fp8 or ctx.debug:
if isinstance(ln_out_total, QuantizedTensorBase):
if isinstance(ln_out_total, QuantizedTensorStorage):
ln_out_total.update_usage(columnwise_usage=True)
else:
ctx.fc1_input_quantizer.set_usage(rowwise=False, columnwise=True)
......@@ -1155,7 +1185,7 @@ class _LayerNormMLP(torch.autograd.Function):
# Note: Synchronize tensor-parallel communication and
# make sure required data is available
if ctx.fp8 or ctx.debug:
if isinstance(dact, QuantizedTensorBase):
if isinstance(dact, QuantizedTensorStorage):
dact.update_usage(columnwise_usage=True)
else:
ctx.fc1_grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
......@@ -1178,7 +1208,11 @@ class _LayerNormMLP(torch.autograd.Function):
else ctx.activation_dtype
),
"quantization_params": ctx.fc1_grad_weight_quantizer,
"accumulate": accumulate_wgrad_into_param_main_grad,
"accumulate": (
accumulate_wgrad_into_param_main_grad
if not getattr(fc2_weight, "overwrite_main_grad", False)
else False
),
"layout": "NT",
"out": origin_fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
"bias": fc1_bias if fuse_gemm_and_bias_fc1_wgrad else None,
......@@ -1486,7 +1520,9 @@ class LayerNormMLP(TransformerEngineBaseModule):
the weight gradient. When enabled, it is assumed that the weights
have an additional `main_grad` attribute (used instead of the
regular `grad`) which is a pre-allocated buffer of the correct
size to accumulate gradients in.
size to accumulate gradients in. If this argument is enabled and the
weight tensor has the attribute `overwrite_main_grad` set to `True`,
`main_grad` is overwritten instead of accumulated into.
return_bias : bool, default = `False`
when set to `True`, this module will not apply the additive bias for FC2, but
instead return the bias value during the forward pass together with the
......@@ -1718,6 +1754,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
self._customize_quantizers_float8_current_scaling(fwd, recipe)
elif recipe.float8_block_scaling():
self._customize_quantizers_float8_blockwise_scaling(fwd, recipe)
elif recipe.nvfp4():
self._customize_quantizers_nvfp4(fwd, recipe)
# elif for other recipes (mxfp8, etc.)
def reset_layer_norm_parameters(self) -> None:
......@@ -1937,7 +1975,10 @@ class LayerNormMLP(TransformerEngineBaseModule):
fc2_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT]
fc2_input_quantizer.set_usage(
rowwise=True,
columnwise=isinstance(fc2_input_quantizer, (MXFP8Quantizer, Float8BlockQuantizer)),
columnwise=isinstance(
fc2_input_quantizer,
(MXFP8Quantizer, Float8BlockQuantizer, NVFP4Quantizer),
),
)
fc1_input_quantizer.internal = True
if fp8_output:
......@@ -2142,7 +2183,29 @@ class LayerNormMLP(TransformerEngineBaseModule):
tex.FP8BwdTensors.GRAD_OUTPUT2
].amax_reduction_group = self.tp_group
def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]:
def _customize_quantizers_nvfp4(self, fwd: bool, recipe: Recipe) -> None:
"""Customize quantizers based on current scaling recipe + layernorm_mlp."""
assert recipe.nvfp4(), "Incorrect recipe."
if fwd:
if self.sequence_parallel and self.set_parallel_mode:
# fc1_input_quantizer: customize input_quantizer with amax reduction TP group, column parallel + sequence parallel here
self.quantizers["scaling_fwd"][
tex.FP8FwdTensors.GEMM1_INPUT
].with_amax_reduction = True
self.quantizers["scaling_fwd"][
tex.FP8FwdTensors.GEMM1_INPUT
].amax_reduction_group = self.tp_group
else:
if self.sequence_parallel and self.set_parallel_mode:
# fc2_grad_output_quantizer: customize grad_output_quantizer with amax reduction TP group, row parallel + sequence parallel here
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT2
].with_amax_reduction = True
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT2
].amax_reduction_group = self.tp_group
def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage]]:
"""Get the weight tensors of the module."""
return [self.fc1_weight, self.fc2_weight]
......
......@@ -27,7 +27,7 @@ from .base import (
_2X_ACC_WGRAD,
)
from ._common import noop_cat, WeightGradStore
from ..fp8 import FP8GlobalStateManager
from ..quantization import FP8GlobalStateManager
from ..utils import (
cast_if_needed,
clear_tensor_data,
......@@ -36,6 +36,7 @@ from ..utils import (
requires_grad,
needs_quantized_gemm,
assert_dim_for_fp8_exec,
assert_dim_for_all_gather,
nvtx_range_pop,
nvtx_range_push,
get_activation_offloading,
......@@ -60,13 +61,14 @@ from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing
from ..tensor.quantized_tensor import (
QuantizedTensor,
QuantizedTensorBase,
QuantizedTensorStorage,
Quantizer,
prepare_for_saving,
restore_from_saved,
)
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor.utils import is_experimental
from ..export import is_in_onnx_export_mode, assert_warmed_up
from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...debug.pytorch.debug_state import TEDebugState
......@@ -154,6 +156,9 @@ class _Linear(torch.autograd.Function):
ub_obj = get_ub(ub_name + "_fprop", fp8)
ub_type = tex.CommOverlapType.AG
# experimental recipe check
experimental = is_experimental(input_quantizer) or is_experimental(weight_quantizer)
# ------------------------------------------------------
# Prepare input tensor
# Note: Cast to expected dtype and perform tensor-parallel communication
......@@ -164,6 +169,7 @@ class _Linear(torch.autograd.Function):
own_quantized_input = False
if fp8:
assert_dim_for_fp8_exec(inputmat, weight)
assert_dim_for_all_gather(inputmat, with_input_all_gather_nccl, input_quantizer)
if save_original_input:
assert not isinstance(
input_quantizer, Float8Quantizer
......@@ -175,7 +181,7 @@ class _Linear(torch.autograd.Function):
if fp8 or debug:
if input_quantizer is None:
raise ValueError("Missing quantizer for input tensor")
if not isinstance(inputmat, QuantizedTensorBase):
if not isinstance(inputmat, QuantizedTensorStorage) and not experimental:
own_quantized_input = True
input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input)
if isinstance(
......@@ -213,7 +219,7 @@ class _Linear(torch.autograd.Function):
else: # Do not all-gather input tensor
if fp8 or debug:
if isinstance(inputmat, QuantizedTensorBase):
if isinstance(inputmat, QuantizedTensorStorage):
inputmat.update_usage(rowwise_usage=True)
else:
if input_quantizer is None:
......@@ -369,7 +375,7 @@ class _Linear(torch.autograd.Function):
if (
backward_needs_input
and own_quantized_input
and isinstance(inputmat, QuantizedTensorBase)
and isinstance(inputmat, QuantizedTensorStorage)
):
if (
ctx.backward_input_needs_gather
......@@ -388,7 +394,7 @@ class _Linear(torch.autograd.Function):
# Weight with column-wise usage is needed for dgrad GEMM.
if inp.requires_grad:
if isinstance(weightmat, QuantizedTensorBase):
if isinstance(weightmat, QuantizedTensorStorage):
weightmat.update_usage(columnwise_usage=True)
if cpu_offloading and saved_inputmat is not None:
......@@ -401,7 +407,7 @@ class _Linear(torch.autograd.Function):
ctx.fsdp_shapes = _fsdp_scatter_tensors(
fsdp_group,
saved_inputmat,
weightmat if fp8 and not isinstance(weight, QuantizedTensorBase) else None,
weightmat if fp8 and not isinstance(weight, QuantizedTensorStorage) else None,
)
nvtx_range_pop(f"{nvtx_label}.fsdp_scatter")
......@@ -471,6 +477,7 @@ class _Linear(torch.autograd.Function):
ctx.main_grad_func = lambda: weight.main_grad
ctx.debug = debug
ctx.experimental = experimental
ctx.cpu_offloading = cpu_offloading
ctx.is_first_microbatch = is_first_microbatch
ctx.use_bias = bias is not None
......@@ -637,10 +644,10 @@ class _Linear(torch.autograd.Function):
inputmat_total_work = None
if ctx.requires_wgrad:
if ctx.fp8 or ctx.debug:
if isinstance(inputmat, QuantizedTensorBase):
if isinstance(inputmat, QuantizedTensorStorage):
# Input tensor is already quantized
pass
elif ctx.debug:
elif ctx.debug or ctx.experimental:
# Debug quantizer will be applied immediately before wgrad GEMM
pass
else:
......@@ -656,7 +663,7 @@ class _Linear(torch.autograd.Function):
quantizer.set_usage(rowwise=False, columnwise=True)
inputmat = quantizer(inputmat)
else:
if isinstance(inputmat, QuantizedTensorBase):
if isinstance(inputmat, QuantizedTensorStorage):
inputmat = inputmat.dequantize(dtype=ctx.activation_dtype)
else:
inputmat = cast_if_needed(inputmat, ctx.activation_dtype)
......@@ -701,9 +708,11 @@ class _Linear(torch.autograd.Function):
if ctx.requires_dgrad:
# Make sure required data is available
if isinstance(grad_output, QuantizedTensorBase):
if isinstance(grad_output, QuantizedTensorStorage):
grad_output.update_usage(rowwise_usage=True)
if ctx.weight_quantizer is not None and isinstance(weight_fp8, QuantizedTensorBase):
if ctx.weight_quantizer is not None and isinstance(
weight_fp8, QuantizedTensorStorage
):
weight_fp8.update_usage(columnwise_usage=True)
# Choose whether to use GEMM kernel with split accumulator
......@@ -729,6 +738,7 @@ class _Linear(torch.autograd.Function):
# dgrad GEMM
# Note: dx = dy * w
nvtx_range_push(f"{nvtx_label}.dgrad_gemm")
gemm_out, *_, reduce_scatter_out = general_gemm(
weight_fp8,
......@@ -786,7 +796,7 @@ class _Linear(torch.autograd.Function):
inputmat_total_work.wait()
inputmat_total_work = None
if ctx.fp8 or ctx.debug:
if isinstance(inputmat_total, QuantizedTensorBase):
if isinstance(inputmat_total, QuantizedTensorStorage):
inputmat_total.update_usage(columnwise_usage=True)
else:
ctx.input_quantizer.set_usage(rowwise=False, columnwise=True)
......@@ -828,7 +838,7 @@ class _Linear(torch.autograd.Function):
)
if ctx.fp8 or ctx.debug:
if isinstance(grad_output, QuantizedTensorBase):
if isinstance(grad_output, QuantizedTensorStorage):
grad_output.update_usage(columnwise_usage=True)
else:
ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
......@@ -864,7 +874,11 @@ class _Linear(torch.autograd.Function):
main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype
),
"quantization_params": ctx.grad_weight_quantizer,
"accumulate": accumulate_wgrad_into_param_main_grad,
"accumulate": (
accumulate_wgrad_into_param_main_grad
if not getattr(weight, "overwrite_main_grad", False)
else False
),
"layout": "NT",
"out": main_grad if ctx.fuse_wgrad_accumulation else None,
"bias": (bias if (grad_bias is None and not ctx.fp8) else None),
......@@ -984,7 +998,7 @@ class _Linear(torch.autograd.Function):
nvtx_range_pop(f"{nvtx_label}.reduce_and_update_fp8_tensors")
# Scatter fp8 weight buffers
if ctx.fp8 and not isinstance(weight, QuantizedTensorBase):
if ctx.fp8 and not isinstance(weight, QuantizedTensorStorage):
_fsdp_scatter_tensors(ctx.fsdp_group, weight_fp8)
return (
wgrad,
......@@ -1086,7 +1100,9 @@ class Linear(TransformerEngineBaseModule):
the weight gradient. When enabled, it is assumed that the weights
have an additional `main_grad` attribute (used instead of the
regular `grad`) which is a pre-allocated buffer of the correct
size to accumulate gradients in.
size to accumulate gradients in. If this argument is enabled and the
weight tensor has the attribute `overwrite_main_grad` set to `True`,
`main_grad` is overwritten instead of accumulated into.
return_bias : bool, default = `False`
when set to `True`, this module will not apply the additive bias itself, but
instead return the bias value during the forward pass together with the
......@@ -1363,6 +1379,8 @@ class Linear(TransformerEngineBaseModule):
self._customize_quantizers_float8_current_scaling(fwd, recipe)
elif recipe.float8_block_scaling():
self._customize_quantizers_float8_blockwise_scaling(fwd, recipe)
elif recipe.nvfp4():
self._customize_quantizers_nvfp4(fwd, recipe)
# elif for other recipes (mxfp8, etc.)
def reset_parameters(self, defer_init=False):
......@@ -1452,7 +1470,6 @@ class Linear(TransformerEngineBaseModule):
if not debug
else self._get_debug_quantizers(fp8_output, fp8_grad)
)
if debug:
if self.no_debug_features_active(quantizers):
debug = False
......@@ -1557,7 +1574,7 @@ class Linear(TransformerEngineBaseModule):
for name, q in zip(names, original_quantizers)
)
def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]:
def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage]]:
"""Get the weight tensors of the module."""
unfused_weights = [getattr(self, name) for name in self.weight_names]
if any(isinstance(w, QuantizedTensor) for w in unfused_weights):
......@@ -1693,6 +1710,28 @@ class Linear(TransformerEngineBaseModule):
tex.FP8BwdTensors.GRAD_OUTPUT1
].amax_reduction_group = self.tp_group
def _customize_quantizers_nvfp4(self, fwd: bool, recipe: Recipe) -> None:
"""Customize quantizers based on current scaling recipe + linear."""
assert recipe.nvfp4(), "Incorrect recipe."
if fwd:
if self.sequence_parallel and self.parallel_mode == "column":
# customize input_quantizer with amax reduction TP group
self.quantizers["scaling_fwd"][
tex.FP8FwdTensors.GEMM1_INPUT
].with_amax_reduction = True
self.quantizers["scaling_fwd"][
tex.FP8FwdTensors.GEMM1_INPUT
].amax_reduction_group = self.tp_group
else:
if self.sequence_parallel and self.parallel_mode == "row":
# customize grad_output_quantizer with amax reduction TP group
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT1
].with_amax_reduction = True
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT1
].amax_reduction_group = self.tp_group
def _get_weight_quantizers(self) -> List[Quantizer]:
"""Get the weight quantizers of the module."""
if not self.fp8 and not self.fp8_calibration:
......
......@@ -11,19 +11,19 @@ import torch
from transformer_engine_torch import FP8TensorMeta
from .. import torch_version
from ..fp8 import FP8GlobalStateManager
from ..quantization import FP8GlobalStateManager
from ..tensor.float8_tensor import Float8Tensor
from ..tensor.quantized_tensor import QuantizedTensorBase
from ..tensor.quantized_tensor import QuantizedTensorStorage
from ..utils import canonicalize_dtype
def is_quantized_tensor(tensor: torch.Tensor | QuantizedTensorBase) -> bool:
def is_quantized_tensor(tensor: torch.Tensor | QuantizedTensorStorage) -> bool:
"""Check if tensor is a quantized tensor"""
return isinstance(tensor, QuantizedTensorBase)
return isinstance(tensor, QuantizedTensorStorage)
def maybe_dequantize(
tensor: torch.Tensor | QuantizedTensorBase, dtype: torch.dtype | None = None
tensor: torch.Tensor | QuantizedTensorStorage, dtype: torch.dtype | None = None
) -> torch.Tensor:
"""Dequantize tensor to given dtype or just convert if not a quantized tensor"""
if is_quantized_tensor(tensor):
......
......@@ -4,7 +4,19 @@
"""Single tensor operations supported by the operation fuser."""
from .activation import GELU, GEGLU, QGELU, QGEGLU, ReLU, ReGLU, SReLU, SReGLU, SiLU, SwiGLU
from .activation import (
GELU,
GEGLU,
QGELU,
QGEGLU,
ReLU,
ReGLU,
SReLU,
SReGLU,
SiLU,
SwiGLU,
ClampedSwiGLU,
)
from .add_extra_input import AddExtraInput
from .all_gather import AllGather
from .all_reduce import AllReduce
......
......@@ -28,6 +28,7 @@ __all__ = [
"SReGLU",
"SiLU",
"SwiGLU",
"ClampedSwiGLU",
]
......@@ -392,3 +393,38 @@ class SwiGLU(_ActivationOperation):
def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
return tex.dswiglu(*args, **kwargs)
class ClampedSwiGLU(_ActivationOperation):
r"""GPT-OSS
Implementation based on `GPT-OSS<https://github.com/openai/gpt-oss/blob/a0a84273e9e0c14a233cb9befdfd159c2bcfa6cd/gpt_oss/torch/model.py#L250>`__.
This activation has two differences compared to the original SwiGLU
1. Both gate and pre-activations are clipped based on parameter limit.
2. Activation uses sigmoid(alpha * x) instead of sigmoid(x) used in Swish activation.
.. warning:: The input tensor is chunked along the last dimension to get gates/pre-activations, which is different
from the GPT-OSS implementation, where the gates/pre-activations are assumed to be interleaved in the input tensor.
Parameters
----------
limit: float
The clamp limit.
alpha: float
The scaling factor for the sigmoid function used in the activation.
cache_quantized_input: bool, default = False
Quantize input tensor when caching for use in the backward pass.
"""
def __init__(
self, *, limit: float = 7.0, alpha: float = 1.702, cache_quantized_input: bool = False
):
super().__init__(cache_quantized_input=cache_quantized_input)
self.limit = limit
self.alpha = alpha
def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor:
return tex.clamped_swiglu(*args, limit=self.limit, alpha=self.alpha, **kwargs)
def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
return tex.clamped_dswiglu(*args, limit=self.limit, alpha=self.alpha, **kwargs)
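For reference, the math described in the ClampedSwiGLU docstring can be written in a few lines of PyTorch. This sketch follows the GPT-OSS code linked above (the exact clamp bounds and the `+ 1` on the linear branch come from that reference, not from this diff) and uses the chunked, non-interleaved layout noted in the warning; which half of the input is treated as the gate is an assumption here.

import torch

def clamped_swiglu_ref(x: torch.Tensor, limit: float = 7.0, alpha: float = 1.702) -> torch.Tensor:
    # Chunk along the last dimension (the layout used by this op),
    # not interleaved as in the original GPT-OSS model code.
    gate, pre_act = x.chunk(2, dim=-1)
    gate = gate.clamp(max=limit)                    # clip the gate from above
    pre_act = pre_act.clamp(min=-limit, max=limit)  # clip the pre-activation on both sides
    # Swish with sigmoid(alpha * x) instead of sigmoid(x), applied to the gate.
    return (gate * torch.sigmoid(alpha * gate)) * (pre_act + 1)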
......@@ -19,7 +19,7 @@ from ...distributed import (
gather_along_first_dim,
reduce_scatter_along_first_dim,
)
from ...fp8 import FP8GlobalStateManager, Recipe
from ...quantization import FP8GlobalStateManager, Recipe
from ...module.base import (
_2X_ACC_FPROP,
_2X_ACC_DGRAD,
......@@ -29,7 +29,7 @@ from ...module.base import (
)
from ...tensor import Quantizer
from ...tensor.float8_tensor import Float8Quantizer
from ...tensor._internal.float8_tensor_base import Float8TensorBase
from ...tensor.storage.float8_tensor_storage import Float8TensorStorage
from ...utils import (
canonicalize_device,
canonicalize_dtype,
......@@ -80,7 +80,9 @@ class BasicLinear(BasicOperation):
autograd. The weight's `main_grad` must be set externally and
there is no guarantee that `grad` will be set or be
meaningful. This is primarily intended to integrate with
Megatron-LM.
Megatron-LM. If this argument is enabled and the weight tensor
has the attribute `overwrite_main_grad` set to `True`, `main_grad`
is overwritten instead of accumulated into.
userbuffers_options, dict, optional
Options for overlapping tensor-parallel communication with
compute using Userbuffers. This feature is highly
......@@ -301,8 +303,8 @@ class BasicLinear(BasicOperation):
"Tried to quantize weight with deferred initialization "
"due to meta device, but no quantizer was available. "
"This is most likely because the weight was initialized "
"within fp8_model_init, but the forward pass was not "
"performed within fp8_autocast."
"within quantized_model_init, but the forward pass was not "
"performed within autocast."
)
quantizer.set_usage(
rowwise=True,
......@@ -322,6 +324,20 @@ class BasicLinear(BasicOperation):
if self.weight.device.type == "meta":
self.reset_parameters()
def pre_fuser_forward(self, *, requires_grad: bool) -> None:
super().pre_fuser_forward(requires_grad=requires_grad)
if FP8GlobalStateManager.is_fp8_enabled():
# Configure quantizer usages
# Note: We cache the quantized input for backward pass,
# but discard the quantized weights.
weight_requires_grad = requires_grad and self.weight.requires_grad
input_quantizer = self.get_quantizer("forward", 0)
weight_quantizer = self.get_quantizer("forward", 1)
grad_output_quantizer = self.get_quantizer("backward", 0)
input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
weight_quantizer.set_usage(rowwise=True, columnwise=False)
grad_output_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None:
super().reset_recipe_state(recipe=recipe)
......@@ -352,6 +368,35 @@ class BasicLinear(BasicOperation):
and not getattr(self, "_with_quantized_weight", False)
)
# Recipe-specific configuration
# Note: This function may be called in base class constructor,
# before any basic linear attrs have been set.
if recipe is not None:
if recipe.float8_current_scaling():
input_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
input_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
weight_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_weight.power_2_scale
weight_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_weight.amax_epsilon
grad_output_quantizer.force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale
grad_output_quantizer.amax_epsilon_scales = recipe.fp8_quant_bwd_grad.amax_epsilon
if getattr(self, "sequence_parallel", False):
tensor_parallel_mode = getattr(self, "tensor_parallel_mode", None)
if tensor_parallel_mode == "column":
input_quantizer.with_amax_reduction = True
input_quantizer.amax_reduction_group = self.tensor_parallel_group
elif tensor_parallel_mode == "row":
grad_output_quantizer.with_amax_reduction = True
grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group
if recipe.nvfp4():
if getattr(self, "sequence_parallel", False):
tensor_parallel_mode = getattr(self, "tensor_parallel_mode", None)
if tensor_parallel_mode == "column":
input_quantizer.with_amax_reduction = True
input_quantizer.amax_reduction_group = self.tensor_parallel_group
elif tensor_parallel_mode == "row":
grad_output_quantizer.with_amax_reduction = True
grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group
@staticmethod
def _functional_forward(
input: torch.Tensor, # pylint: disable=redefined-builtin
......@@ -568,7 +613,7 @@ class BasicLinear(BasicOperation):
# Prepare input tensor for backward pass
if weight_requires_grad:
if with_quantized_compute and is_quantized_tensor(x_local):
if not (isinstance(x_local, Float8TensorBase) and with_x_all_gather):
if not (isinstance(x_local, Float8TensorStorage) and with_x_all_gather):
# FP8 does not support all-gather of transpose data
x_local.update_usage(rowwise_usage=False, columnwise_usage=True)
else:
......@@ -731,7 +776,7 @@ class BasicLinear(BasicOperation):
if with_quantized_compute:
if input_quantizer is None:
raise ValueError("Missing quantizer for input tensor")
input_quantizer.set_usage(columnwise=True)
input_quantizer.set_usage(rowwise=False, columnwise=True)
if with_x_all_gather:
x, x_async = gather_along_first_dim(
x_local,
......@@ -912,34 +957,13 @@ class BasicLinear(BasicOperation):
input_requires_grad = ctx.requires_grad
weight_requires_grad = ctx.requires_grad and self.weight.requires_grad
# FP8 metadata
# Quantizers
input_quantizer = self.get_quantizer("forward", 0)
weight_quantizer = self.get_quantizer("forward", 1)
output_quantizer = next_op_input_quantizer
grad_output_quantizer = self.get_quantizer("backward", 0)
grad_input_quantizer = prev_op_grad_output_quantizer
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
if with_quantized_compute:
# Configure quantizers
# Note: We cache the quantized input for backward pass,
# but discard the quantized weights.
input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
weight_quantizer.set_usage(rowwise=True, columnwise=False)
recipe = FP8GlobalStateManager.get_fp8_recipe()
if recipe.float8_current_scaling():
input_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
input_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
weight_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
weight_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
grad_output_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
grad_output_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
if self.sequence_parallel and self.tensor_parallel_mode == "column":
input_quantizer.with_amax_reduction = True
input_quantizer.amax_reduction_group = self.tensor_parallel_group
if self.sequence_parallel and self.tensor_parallel_mode == "row":
grad_output_quantizer.with_amax_reduction = True
grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group
# Get autocast dtype if needed
if torch.is_autocast_enabled():
......@@ -997,6 +1021,7 @@ class BasicLinear(BasicOperation):
weight_param = self.weight
if hasattr(weight_param, "__fsdp_param__"):
weight_param.main_grad = weight_param.get_main_grad()
accumulate_into_main_grad = not getattr(weight_param, "overwrite_main_grad", False)
if not hasattr(weight_param, "main_grad"):
raise RuntimeError(
"BasicLinear op is configured with "
......
......@@ -11,7 +11,7 @@ import torch
import transformer_engine_torch as tex
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...tensor import Quantizer
from ...tensor._internal.float8_tensor_base import Float8TensorBase
from ...tensor.storage.float8_tensor_storage import Float8TensorStorage
from .._common import maybe_autocast_dtype, maybe_dequantize
from ..op import BasicOperation, OperationContext
......@@ -56,7 +56,7 @@ class Dropout(BasicOperation):
out = input_
elif impl == "fused":
x = input_
if not isinstance(x, Float8TensorBase):
if not isinstance(x, Float8TensorStorage):
x = maybe_dequantize(x, dtype=dtype)
out, mask = tex.dropout_fwd(x, self.dropout_probability)
elif impl == "unfused":
......
......@@ -9,7 +9,7 @@ from typing import Optional
import torch
from ...fp8 import FP8GlobalStateManager
from ...quantization import FP8GlobalStateManager
from .._common import is_quantized_tensor
from ..op import BasicOperation, OperationContext
from ...tensor import Quantizer
......@@ -18,8 +18,8 @@ from ...tensor import Quantizer
class Quantize(BasicOperation):
"""Quantize tensor data
Uses FP8 recipe from `fp8_autocast` context. When called outside
of an `fp8_autocast` context, this is an identity operation.
Uses recipe from `autocast` context. When called outside
of an `autocast` context, this is an identity operation.
Parameters
----------
......
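A usage sketch for the Quantize op and the renamed autocast entry point referenced in the docstring above. The import paths for `Quantize` and `autocast` are assumed from this merge and the current ops API, and the recipe choice is arbitrary; treat this as an illustration rather than a verified snippet.

import torch
from transformer_engine.common.recipe import DelayedScaling
from transformer_engine.pytorch.ops import Quantize            # assumed export path
from transformer_engine.pytorch.quantization import autocast   # renamed from fp8_autocast in this merge

op = Quantize()
x = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16)

# Inside an autocast region the op quantizes with the active recipe ...
with autocast(recipe=DelayedScaling()):
    y = op(x)

# ... outside any autocast region it acts as an identity.
z = op(x)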
......@@ -10,7 +10,7 @@ from typing import Optional
import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.fp8 import Recipe
from transformer_engine.pytorch.quantization import Recipe
from transformer_engine.pytorch.ops.basic import Bias
from transformer_engine.pytorch.ops.basic.activation import (
_ActivationOperation,
......
......@@ -59,6 +59,7 @@ class BackwardLinearAdd(FusedOperation):
weight_param = linear_op.weight
if hasattr(weight_param, "__fsdp_param__"):
weight_param.main_grad = weight_param.get_main_grad()
accumulate_into_main_grad = not getattr(weight_param, "overwrite_main_grad", False)
if not hasattr(weight_param, "main_grad"):
raise RuntimeError(
"BasicLinear op is configured with "
......
......@@ -60,6 +60,7 @@ class BackwardLinearScale(FusedOperation):
weight_param = linear_op.weight
if hasattr(weight_param, "__fsdp_param__"):
weight_param.main_grad = weight_param.get_main_grad()
accumulate_into_main_grad = not getattr(weight_param, "overwrite_main_grad", False)
if not hasattr(weight_param, "main_grad"):
raise RuntimeError(
"BasicLinear op is configured with "
......
......@@ -11,7 +11,7 @@ from typing import Any, Optional
import torch
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...fp8 import FP8GlobalStateManager
from ...quantization import FP8GlobalStateManager
from ...tensor import Quantizer
from ..basic import BasicLinear, Bias
from ..op import FusedOperation, FusibleOperation, OperationContext
......@@ -85,7 +85,7 @@ class ForwardLinearBiasActivation(FusedOperation):
input_requires_grad = linear_op_ctx.requires_grad
weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad
# FP8 metadata
# Quantizers
input_quantizer = linear_op.get_quantizer("forward", 0)
weight_quantizer = linear_op.get_quantizer("forward", 1)
output_quantizer = next_op_input_quantizer
......
......@@ -11,7 +11,7 @@ from typing import Any, Optional
import torch
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...fp8 import FP8GlobalStateManager
from ...quantization import FP8GlobalStateManager
from ...tensor import Quantizer
from ..basic import AddExtraInput, BasicLinear, Bias
from ..op import FusedOperation, FusibleOperation, OperationContext
......@@ -79,7 +79,7 @@ class ForwardLinearBiasAdd(FusedOperation):
input_requires_grad = linear_op_ctx.requires_grad
weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad
# FP8 metadata
# Quantizers
input_quantizer = linear_op.get_quantizer("forward", 0)
weight_quantizer = linear_op.get_quantizer("forward", 1)
output_quantizer = None
......
......@@ -11,7 +11,7 @@ from typing import Any, Optional
import torch
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...fp8 import FP8GlobalStateManager
from ...quantization import FP8GlobalStateManager
from ...tensor import Quantizer
from ..basic import AddExtraInput, BasicLinear, ConstantScale
from ..op import (
......@@ -58,7 +58,7 @@ class ForwardLinearScaleAdd(FusedOperation):
input_requires_grad = linear_op_ctx.requires_grad
weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad
# FP8 metadata
# Quantizers
input_quantizer = linear_op.get_quantizer("forward", 0)
weight_quantizer = linear_op.get_quantizer("forward", 1)
output_quantizer = None
......
......@@ -523,6 +523,7 @@ class UserbuffersBackwardLinear(FusedOperation):
weight_param = linear_op.weight
if hasattr(weight_param, "__fsdp_param__"):
weight_param.main_grad = weight_param.get_main_grad()
accumulate_into_main_grad = not getattr(weight_param, "overwrite_main_grad", False)
if not hasattr(weight_param, "main_grad"):
raise RuntimeError(
"BasicLinear op is configured with "
......
......@@ -14,7 +14,7 @@ from transformer_engine_torch import CommOverlapType
from ...cpp_extensions import general_gemm
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...distributed import get_distributed_world_size
from ...fp8 import FP8GlobalStateManager
from ...quantization import FP8GlobalStateManager
from ...module.base import (
fill_userbuffers_buffer_for_all_gather,
get_ub,
......@@ -23,7 +23,7 @@ from ...module.base import (
)
from ...tensor.quantized_tensor import Quantizer
from ...tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
from ...tensor._internal.float8_tensor_base import Float8TensorBase
from ...tensor.storage.float8_tensor_storage import Float8TensorStorage
from .._common import maybe_dequantize, is_quantized_tensor
from ..basic import BasicLinear, Bias, ReduceScatter
from ..op import (
......@@ -267,7 +267,7 @@ class UserbuffersForwardLinear(FusedOperation):
# Prepare input tensor for backward pass
if weight_requires_grad:
if with_quantized_compute and is_quantized_tensor(x_local):
if not (isinstance(x_local, Float8TensorBase) and with_ub_all_gather):
if not (isinstance(x_local, Float8TensorStorage) and with_ub_all_gather):
# FP8 does not support all-gather of transpose data
x_local.update_usage(rowwise_usage=False, columnwise_usage=True)
else:
......
......@@ -11,7 +11,7 @@ import itertools
import torch
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager, Recipe, DelayedScaling
from transformer_engine.pytorch.quantization import FP8GlobalStateManager, Recipe, DelayedScaling
from transformer_engine.pytorch.ops.op import (
BasicOperation,
FusibleOperation,
......@@ -472,6 +472,10 @@ class OperationFuser:
# Attempt to fuse operations if necessary
self.maybe_fuse_ops(is_grad_enabled, recipe, input, basic_op_extra_inputs)
# Initialization before forward
for idx, op in enumerate(self._basic_ops):
op.pre_fuser_forward(requires_grad=idx >= self.first_op_requiring_backward)
# Fuser forward pass
if is_grad_enabled:
forward_func = _OperationFuserAutogradFunction.apply
......
......@@ -14,10 +14,10 @@ from typing import Any, Optional
import torch
from transformer_engine.common.recipe import Recipe
from ..fp8 import (
from ..quantization import (
FP8GlobalStateManager,
RecipeState,
fp8_autocast,
autocast,
)
from ..tensor import Quantizer
......@@ -65,6 +65,13 @@ class FusibleOperation(torch.nn.Module, metaclass=abc.ABCMeta):
def pre_first_fuser_forward(self) -> None:
"""Preprocessing before first fuser forward pass"""
def pre_fuser_forward(
self,
*,
requires_grad: bool, # pylint: disable=unused-argument
) -> None:
"""Preprocessing before fuser forward pass"""
def get_input_quantizer(self) -> Optional[Quantizer]:
"""Get builder class for quantized input tensor"""
......@@ -588,6 +595,9 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
extra[key] = val
state[mode]["extra_fp8_variables"] = extra
if not state:
return torch.empty(0, dtype=torch.uint8)
# Serialize state into byte tensor
torch.cuda.synchronize()
state_serialized = bytearray(pickle.dumps(state))
......@@ -624,7 +634,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
# Get op's quantizer state, initializing if needed
if self._fp8_metas is None or self._fp8_metas[mode] is None:
with fp8_autocast(fp8_recipe=state[mode]["recipe"]):
with autocast(recipe=state[mode]["recipe"]):
self.reset_recipe_state(recipe=state[mode]["recipe"])
fp8_meta = self._fp8_metas[mode]
......@@ -710,6 +720,10 @@ class FusedOperation(FusibleOperation):
for op in self.basic_ops:
op.pre_first_fuser_forward()
def pre_fuser_forward(self, *, requires_grad: bool) -> None:
for op in self.basic_ops:
op.pre_fuser_forward(requires_grad=requires_grad)
def forward(
self,
input: torch.Tensor, # pylint: disable=redefined-builtin
......