Commit 44740c6c authored by yuguo


Merge commit '7a9a0825' of https://github.com/NVIDIA/TransformerEngine
parents 8113d9e0 7a9a0825
......@@ -78,6 +78,7 @@ from ..tensor.quantized_tensor import (
from ..cpp_extensions import (
general_gemm,
)
from ..export import is_in_onnx_export_mode, assert_warmed_up
from ...debug.pytorch.utils import any_feature_enabled
from ...debug.pytorch.debug_state import TEDebugState
......@@ -86,16 +87,16 @@ __all__ = ["LayerNormMLP"]
def _get_act_func_supported_list(recipe: Optional[Recipe] = None):
if recipe is None:
# bf16 (recipe is None): [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu]
# bf16 (recipe is None):
return {
"gelu": (tex.gelu, tex.dgelu, tex.dbias_dgelu),
"relu": (tex.relu, tex.drelu, tex.dbias_drelu),
"gelu": (tex.gelu, tex.dgelu, None),
"relu": (tex.relu, tex.drelu, None),
"geglu": (tex.geglu, tex.dgeglu, None),
"reglu": (tex.reglu, tex.dreglu, None),
"swiglu": (tex.swiglu, tex.dswiglu, None),
"qgelu": (tex.qgelu, tex.dqgelu, tex.dbias_dqgelu),
"qgelu": (tex.qgelu, tex.dqgelu, None),
"qgeglu": (tex.qgeglu, tex.dqgeglu, None),
"srelu": (tex.srelu, tex.dsrelu, tex.dbias_dsrelu),
"srelu": (tex.srelu, tex.dsrelu, None),
}
if recipe.delayed() or recipe.mxfp8():
# Delayed scaling, fusion supported list: [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu]
......@@ -553,8 +554,20 @@ class _LayerNormMLP(torch.autograd.Function):
)
if fuse_wgrad_accumulation:
ctx.fc1_main_grad = fc1_weight.main_grad if fc1_weight.requires_grad else None
ctx.fc2_main_grad = fc2_weight.main_grad if fc2_weight.requires_grad else None
# This check is needed to ensure that main_grad is not created
# during the forward pass when using MCore FSDP as it creates
# the main_grad buffer lazily before backprop
if hasattr(fc1_weight, "__fsdp_param__") and hasattr(fc2_weight, "__fsdp_param__"):
# MCore FSDP creates main_grad lazily before backward
ctx.fc1_main_grad_func = (
fc1_weight.get_main_grad if fc1_weight.requires_grad else lambda: None
)
ctx.fc2_main_grad_func = (
fc2_weight.get_main_grad if fc2_weight.requires_grad else lambda: None
)
else:
ctx.fc1_main_grad_func = lambda: fc1_weight.main_grad
ctx.fc2_main_grad_func = lambda: fc2_weight.main_grad
ctx.save_for_backward(*tensors_to_save)
ctx.tensor_objects = tensor_objects
......@@ -654,14 +667,14 @@ class _LayerNormMLP(torch.autograd.Function):
# Since main_grad can be modified in place, it should not be part of saved_tensors
fc1_weight_main_grad = (
ctx.fc1_main_grad
ctx.fc1_main_grad_func()
if fc1_weight is not None
and ctx.fuse_wgrad_accumulation
and ctx.fc1_weight_requires_grad
else None
)
fc2_weight_main_grad = (
ctx.fc2_main_grad
ctx.fc2_main_grad_func()
if origin_fc2_weight is not None
and ctx.fuse_wgrad_accumulation
and ctx.fc2_weight_requires_grad
......@@ -1727,6 +1740,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
first microbatch (since it is the first gradient being
produced)
"""
if is_in_onnx_export_mode():
return self.onnx_forward(inp)
debug = TEDebugState.debug_enabled
if debug:
self._validate_name()
......@@ -1917,6 +1932,89 @@ class LayerNormMLP(TransformerEngineBaseModule):
fc2_grad_output_quantizer,
)
def onnx_forward(self, inp: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
"""
ONNX-compatible version of the forward function that provides numerical equivalence
while only using operations that have defined ONNX symbolic translations.
This simplified implementation is designed specifically for inference scenarios.
"""
from ..export import onnx_layernorm, onnx_gemm
assert not TEDebugState.debug_enabled, "Debug mode is not supported in ONNX export"
assert_warmed_up(self)
(
fc1_input_quantizer,
fc1_weight_quantizer,
fc2_input_quantizer,
fc2_weight_quantizer,
output_quantizer,
*_,
) = self._get_quantizers(False)
inp_dtype = inp.dtype
fc1_weight, fc2_weight = self._get_weight_tensors()
fc1_bias = self.fc1_bias if self.use_bias else None
fc2_bias = self.fc2_bias if self.use_bias else None
# layernorm + fp8 cast
ln_out, ln_out_return = onnx_layernorm(
inp,
self.layer_norm_weight,
self.layer_norm_bias,
self.eps,
self.normalization,
self.zero_centered_gamma,
inp_dtype,
self.return_layernorm_output,
fc1_input_quantizer,
)
if fc1_weight_quantizer is not None:
fc1_weight_q = fc1_weight_quantizer.onnx_quantize(fc1_weight)
fc1_weight = fc1_weight_quantizer.onnx_dequantize(fc1_weight_q)
fc1_weight = fc1_weight.to(inp_dtype)
fc1_out = onnx_gemm(fc1_weight, ln_out, fc1_bias)
fc1_out = fc1_out.to(torch.float32) # activation is computed in fp32
activation_map = {
"gelu": lambda x: torch.nn.functional.gelu(x, approximate="tanh"),
"relu": torch.nn.functional.relu,
"geglu": lambda x: torch.nn.functional.gelu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1],
"reglu": lambda x: torch.nn.functional.relu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1],
"swiglu": lambda x: torch.nn.functional.silu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1],
"qgeglu": lambda x: torch.nn.functional.gelu(x.chunk(2, -1)[0], approximate="tanh")
* x.chunk(2, -1)[1],
"qgelu": lambda x: torch.nn.functional.gelu(x, approximate="tanh"),
"srelu": torch.nn.functional.softplus,
}
if self.activation not in activation_map:
raise ValueError(f"Unsupported activation in onnx export: {self.activation}")
act_out = activation_map[self.activation](fc1_out)
if fc2_weight_quantizer is not None:
fc2_weight_q = fc2_weight_quantizer.onnx_quantize(fc2_weight)
fc2_weight = fc2_weight_quantizer.onnx_dequantize(fc2_weight_q)
fc2_weight = fc2_weight.to(inp_dtype)
if fc2_input_quantizer is not None:
act_out_q = fc2_input_quantizer.onnx_quantize(act_out)
act_out = fc2_input_quantizer.onnx_dequantize(act_out_q)
act_out = act_out.to(inp_dtype)
fc2_out = onnx_gemm(fc2_weight, act_out, fc2_bias)
if output_quantizer is not None:
raise NotImplementedError("ONNX export of quantized output is not supported")
if self.return_layernorm_output:
if self.return_bias:
return fc2_out, fc2_bias.to(inp_dtype), ln_out_return
return fc2_out, ln_out_return
if self.return_bias:
return fc2_out, fc2_bias.to(inp_dtype)
return fc2_out
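For reference, a minimal sketch of the gated-activation convention used in the activation_map above, with arbitrary shapes: the FC1 output carries twice the FFN hidden size in its last dimension, and the ONNX path gates the first half with the second (shown here for swiglu).

import torch

fc1_out = torch.randn(4, 8)                            # last dim = 2 * ffn_hidden_size
gate, value = fc1_out.chunk(2, dim=-1)                 # split into two halves
swiglu_out = torch.nn.functional.silu(gate) * value    # shape (4, 4), matches activation_map["swiglu"]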
def _get_debug_quantizers(self, fp8_output):
from ...debug.pytorch.debug_quantization import DebugQuantizer
......
......@@ -68,6 +68,7 @@ from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantize
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from ..export import is_in_onnx_export_mode, assert_warmed_up
from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...debug.pytorch.debug_state import TEDebugState
from ...debug.pytorch.utils import any_feature_enabled
......@@ -117,6 +118,7 @@ class _Linear(torch.autograd.Function):
module: torch.nn.Module,
skip_fp8_weight_update: bool,
symmetric_ar_type: str,
save_original_input: bool = False,
debug: Optional[bool] = False,
) -> torch.Tensor:
# pylint: disable=missing-function-docstring
......@@ -157,6 +159,11 @@ class _Linear(torch.autograd.Function):
own_quantized_input = False
if fp8:
assert_dim_for_fp8_exec(inputmat, weight)
if save_original_input:
assert not isinstance(
input_quantizer, Float8Quantizer
), "DelayedScaling recipe is not supported with save_original_input"
if with_input_all_gather_nccl or ub_overlap_ag_fprop: # All-gather input tensor
# Cast local input tensor if needed
......@@ -164,7 +171,9 @@ class _Linear(torch.autograd.Function):
if input_quantizer is None:
raise ValueError("Missing quantizer for input tensor")
if not isinstance(inputmat, QuantizedTensorBase):
input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input)
input_quantizer.set_usage(
rowwise=True, columnwise=backward_needs_input and not save_original_input
)
if isinstance(
input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
):
......@@ -201,7 +210,9 @@ class _Linear(torch.autograd.Function):
else:
if input_quantizer is None:
raise ValueError("Missing quantizer for input tensor")
input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input)
input_quantizer.set_usage(
rowwise=True, columnwise=backward_needs_input and not save_original_input
)
inputmat = input_quantizer(inputmat)
own_quantized_input = True
else:
......@@ -330,6 +341,9 @@ class _Linear(torch.autograd.Function):
# ------------------------------------------------------
if is_grad_enabled:
if save_original_input:
inputmat = inp
ctx.weight_quantizer = weight_quantizer
saved_inputmat = None
......@@ -338,6 +352,7 @@ class _Linear(torch.autograd.Function):
)
if backward_needs_input:
if not save_original_input:
if own_quantized_input and isinstance(inputmat, QuantizedTensorBase):
# For sequence parallel in vanilla FP8, rowwise data is
# to gather the input. For MXFP8, columnwise only data
......@@ -398,7 +413,14 @@ class _Linear(torch.autograd.Function):
ctx.grad_output_quantizer = grad_output_quantizer
ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
if fuse_wgrad_accumulation and weight.requires_grad:
ctx.main_grad = weight.main_grad
# This check is needed to ensure that main_grad is not created
# during the forward pass when using MCore FSDP as it creates
# the main_grad buffer lazily before backprop
if hasattr(weight, "__fsdp_param__"):
# MCore FSDP creates main_grad lazily before backward
ctx.main_grad_func = weight.get_main_grad
else:
ctx.main_grad_func = lambda: weight.main_grad
ctx.debug = debug
ctx.cpu_offloading = cpu_offloading
......@@ -454,7 +476,7 @@ class _Linear(torch.autograd.Function):
# Since main_grad can be modified in place, it should not be part of saved_tensors
main_grad = (
ctx.main_grad
ctx.main_grad_func()
if weight is not None and ctx.fuse_wgrad_accumulation and ctx.requires_wgrad
else None
)
......@@ -550,6 +572,24 @@ class _Linear(torch.autograd.Function):
# --------------------------------------------------
inputmat_total = None
inputmat_total_work = None
if ctx.requires_wgrad:
input_is_quantized = isinstance(inputmat, QuantizedTensorBase)
if ctx.fp8 or ctx.debug:
if not input_is_quantized:
quantizer = ctx.input_quantizer
if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
quantizer.set_usage(
rowwise=True,
columnwise=not ctx.backward_input_needs_gather,
)
else:
quantizer.set_usage(rowwise=False, columnwise=True)
inputmat = quantizer(inputmat)
else:
if input_is_quantized:
inputmat = inputmat.dequantize(dtype=ctx.activation_dtype)
else:
inputmat = cast_if_needed(inputmat, ctx.activation_dtype)
if ctx.backward_input_needs_gather:
quantizer = None
if ctx.fp8 or ctx.debug:
......@@ -894,6 +934,7 @@ class _Linear(torch.autograd.Function):
None, # module
None, # skip_fp8_weight_update
None, # symmetric_ar_type
None, # save_original_input
None, # debug
)
......@@ -976,6 +1017,11 @@ class Linear(TransformerEngineBaseModule):
This can help in latency bound communication situations.
Requires PyTorch version 2.7.0 or higher. When set to None, standard all-reduce
is used.
save_original_input : bool, default = `False`
If set to `True`, the original input tensor is saved for the backward pass
rather than the cast/quantized copy. When the same input feeds multiple modules,
saving the original tensor once can reduce memory usage (see the usage sketch below).
Not compatible with the FP8 DelayedScaling recipe.
"""
def __init__(
......@@ -1003,6 +1049,7 @@ class Linear(TransformerEngineBaseModule):
ub_name: Optional[str] = None,
delay_wgrad_compute: bool = False,
symmetric_ar_type: Optional[str] = None,
save_original_input: bool = False,
name: Optional[str] = None,
) -> None:
super().__init__()
......@@ -1017,6 +1064,7 @@ class Linear(TransformerEngineBaseModule):
self.get_rng_state_tracker = get_rng_state_tracker
self.rng_tracker_name = rng_tracker_name
self.symmetric_ar_type = symmetric_ar_type
self.save_original_input = save_original_input
self.name = name
if TEDebugState.debug_enabled:
......@@ -1275,6 +1323,9 @@ class Linear(TransformerEngineBaseModule):
first microbatch (since it is the first gradient being
produced)
"""
if is_in_onnx_export_mode():
return self.onnx_forward(inp, fp8_output)
debug = TEDebugState.debug_enabled
if debug:
self._validate_name()
......@@ -1298,13 +1349,7 @@ class Linear(TransformerEngineBaseModule):
allow_non_contiguous=isinstance(inp, QuantizedTensor),
) as inp:
# Get concatenated weight and bias tensors
unfused_weights = self._get_weight_tensors()
weight_tensor = noop_cat(unfused_weights)
if self.use_bias:
bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names])
else:
bias_tensor = None
weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()
quantizers = (
self._get_quantizers(fp8_output, fp8_grad)
......@@ -1370,6 +1415,7 @@ class Linear(TransformerEngineBaseModule):
self,
skip_fp8_weight_update,
self.symmetric_ar_type,
self.save_original_input,
debug,
)
out = linear_fn(*args)
......@@ -1417,6 +1463,95 @@ class Linear(TransformerEngineBaseModule):
for name, q in zip(names, original_quantizers)
)
def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]:
"""Get the weight tensors of the module."""
unfused_weights = [getattr(self, name) for name in self.weight_names]
if any(isinstance(w, QuantizedTensor) for w in unfused_weights):
if self.fp8:
if len(unfused_weights) != 1:
raise RuntimeError(
"Splitting QuantizedTensor into multiple params is not supported"
)
else:
warnings.warn(
"You are using quantized weights without quantized compute. "
"Please make sure this is intentional."
)
unfused_weights = [w.dequantize() for w in unfused_weights]
return unfused_weights
def _get_weight_and_bias_tensors(self) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
# Get concatenated weight and bias tensors
unfused_weights = self._get_weight_tensors()
if any(isinstance(w, QuantizedTensor) for w in unfused_weights):
if self.fp8:
if len(unfused_weights) != 1:
raise RuntimeError(
"Splitting QuantizedTensor into multiple params is not supported"
)
else:
warnings.warn(
"You are using quantized weights without quantized compute. "
"Please make sure this is intentional."
)
unfused_weights = [w.dequantize() for w in unfused_weights]
weight_tensor = noop_cat(unfused_weights)
if self.use_bias:
bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names])
else:
bias_tensor = None
return weight_tensor, bias_tensor
def onnx_forward(
self,
inp: torch.Tensor,
fp8_output: bool,
) -> torch.Tensor:
"""
ONNX-compatible version of the forward function that provides numerical equivalence
while only using operations that have defined ONNX symbolic translations.
This simplified implementation is designed specifically for inference scenarios.
"""
from ..export import onnx_gemm
assert_warmed_up(self)
assert not TEDebugState.debug_enabled, "Debug mode is not supported in ONNX export."
weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()
(
input_quantizer,
weight_quantizer,
output_quantizer,
*_,
) = self._get_quantizers(fp8_output, False)
inp_dtype = inp.dtype
if input_quantizer is not None:
inp_q = input_quantizer.onnx_quantize(inp)
inp = input_quantizer.onnx_dequantize(inp_q)
inp = inp.to(inp_dtype)
if weight_quantizer is not None:
weight_q = weight_quantizer.onnx_quantize(weight_tensor)
weight_tensor = weight_quantizer.onnx_dequantize(weight_q)
if bias_tensor is not None:
bias_tensor = bias_tensor.to(inp_dtype)
weight_tensor = weight_tensor.to(inp_dtype)
if self.apply_bias:
output = onnx_gemm(weight_tensor, inp, bias_tensor)
else:
output = onnx_gemm(weight_tensor, inp, None)
if output_quantizer is not None:
raise NotImplementedError("ONNX export of quantized output is not supported")
if self.return_bias:
return output, bias_tensor
return output
def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None:
"""Customize quantizers based on current scaling recipe + linear."""
assert (
......@@ -1464,23 +1599,6 @@ class Linear(TransformerEngineBaseModule):
tex.FP8BwdTensors.GRAD_OUTPUT1
].amax_reduction_group = self.tp_group
def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]:
"""Get the weight tensors of the module."""
unfused_weights = [getattr(self, name) for name in self.weight_names]
if any(isinstance(w, QuantizedTensor) for w in unfused_weights):
if self.fp8:
if len(unfused_weights) != 1:
raise RuntimeError(
"Splitting QuantizedTensor into multiple params is not supported"
)
else:
warnings.warn(
"You are using quantized weights without quantized compute. "
"Please make sure this is intentional."
)
unfused_weights = [w.dequantize() for w in unfused_weights]
return unfused_weights
def _get_weight_quantizers(self) -> List[Quantizer]:
"""Get the weight quantizers of the module."""
if not self.fp8:
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
File containing torch.ops extensions and their corresponding ONNX symbolic functions.
Many transformer engine layers rely on custom calls from the transformer_engine_torch module, making ONNX export challenging because:
1. They often accept Python objects (quantizers), which ONNX does not support.
2. They are complex, incorporating fusions and precomputing values for the backward pass, mechanisms that are unnecessary for ONNX export.
For these reasons, we introduce onnx_forward methods in each layer that are simpler and
primarily leverage torch operators with known ONNX symbolic functions.
These methods avoid fusions and backward pass precomputations.
The main remaining concern is quantization, which PyTorch does not natively support, so we implement the ONNX symbolic functions ourselves.
Since ONNX does not yet support quantization, operators from TensorRT are employed.
The primary goal of ONNX export is to enable inference compatibility with TensorRT.
"""
from typing import Tuple
import math
import torch
import onnxscript
from onnxscript import opset18 as op
from onnx import defs
import transformer_engine_torch as tex
from .tensor.float8_tensor import Float8Quantizer
from .tensor.mxfp8_tensor import MXFP8Quantizer
from .constants import MXFP8_BLOCK_SCALING_SIZE
from .utils import round_up_to_nearest_multiple
from .export import is_in_onnx_export_mode
trt_opset = onnxscript.values.Opset(
"trt", version=1
) # opset from TensorRT which supports FP8 quantization
# ONNX GEMM for inference
def onnx_gemm(weight: torch.Tensor, inp: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
"""ONNX GEMM used for inference."""
reshaped_inp = inp.reshape(-1, inp.shape[-1])
out = torch_onnx_gemm_inf_op(weight, reshaped_inp, bias)
return out.reshape(inp.shape[:-1] + (-1,))
@torch.library.custom_op("tex::gemm_inf", mutates_args=[])
def torch_onnx_gemm_inf_op(
weight: torch.Tensor, inp: torch.Tensor, bias: torch.Tensor
) -> torch.Tensor:
"""Gemm used for inference -- weight is transposed"""
out = inp @ weight.T
if bias is not None:
out = out + bias
return out
@torch_onnx_gemm_inf_op.register_fake
def _(weight, inp, bias):
"""Fake gemm used for inference."""
out = inp @ weight.T
if bias is not None:
out = out + bias
return out
def onnx_gemm_inf_symbolic(
weight: onnxscript.onnx_types.TensorType,
inp: onnxscript.onnx_types.TensorType,
bias: onnxscript.onnx_types.TensorType,
) -> onnxscript.onnx_types.TensorType:
"""Symbolic gemm used for inference."""
return op.Gemm(inp, weight, bias, transA=0, transB=1)
# ONNX FP8 Quantization
@torch.library.custom_op("tex::fp8_quantize", mutates_args=[])
def onnx_quantize_fp8_op(tensor: torch.Tensor, scale: float) -> torch.Tensor:
"""Quantize to Float8Tensor used for inference."""
scale_tensor = torch.tensor(scale, dtype=torch.float32, device=tensor.device)
amax_tensor = torch.tensor([1], dtype=torch.float32, device=tensor.device)
quantizer = Float8Quantizer(scale_tensor, amax_tensor, tex.DType.kFloat8E4M3)
return quantizer.quantize(tensor)._data
@onnx_quantize_fp8_op.register_fake
def _(tensor, *_):
"""Fake quantize to Float8Tensor used for inference."""
return torch.empty(tensor.shape, dtype=torch.uint8, device=tensor.device)
def onnx_quantize_fp8_symbolic(
tensor: onnxscript.onnx_types.TensorType,
scale: float,
) -> onnxscript.onnx_types.UINT8:
"""Symbolic quantize used for inference."""
scale_inv = op.Constant(value_float=1 / scale)
return TRT_FP8QuantizeLinear(tensor, scale_inv)
# Define the schema for the custom operator
schema = defs.OpSchema(
name="TRT_FP8QuantizeLinear",
domain="trt",
since_version=1,
doc="TRT FP8 Quantize Linear used for inference.",
inputs=[
defs.OpSchema.FormalParameter("tensor", "tensor(float)", "Input tensor to quantize"),
defs.OpSchema.FormalParameter("scale", "tensor(float)", "Scale factor for quantization"),
],
outputs=[defs.OpSchema.FormalParameter("output", "tensor(uint8)", "Quantized output tensor")],
)
TRT_FP8QuantizeLinear = onnxscript.values.Op(
opset=trt_opset, name="TRT_FP8QuantizeLinear", op_schema=schema
)
# ONNX FP8 Dequantization
@torch.library.custom_op("tex::fp8_dequantize", mutates_args=[])
def onnx_dequantize_fp8_op(tensor: torch.Tensor, scale: float) -> torch.Tensor:
"""Dequantize from Float8Tensor used for inference."""
scale_tensor = torch.tensor(scale, dtype=torch.float32, device=tensor.device)
quantizer = Float8Quantizer(
scale_tensor, torch.zeros(1).to(tensor.device), tex.DType.kFloat8E4M3
)
quantizer_tensor = quantizer.create_tensor_from_data(tensor, fake_dtype=torch.float32)
return quantizer_tensor.dequantize()
@onnx_dequantize_fp8_op.register_fake
def _(tensor: torch.Tensor, _) -> torch.Tensor:
"""Fake dequantize from Float8Tensor used for inference."""
return torch.empty(tensor.shape, dtype=torch.float32, device=tensor.device)
def onnx_dequantize_fp8_symbolic(
tensor: onnxscript.onnx_types.TensorType, scale: float
) -> onnxscript.onnx_types.TensorType:
"""Symbolic dequantize from Float8Tensor used for inference."""
scale_inv = op.Constant(value_float=1 / scale)
return TRT_FP8DequantizeLinear(tensor, scale_inv)
schema = defs.OpSchema(
name="TRT_FP8DequantizeLinear",
domain="trt",
since_version=1,
doc="TRT FP8 Dequantize Linear from Float8Tensor used for inference.",
inputs=[
defs.OpSchema.FormalParameter("tensor", "tensor(uint8)", "Input tensor to dequantize"),
defs.OpSchema.FormalParameter("scale", "tensor(float)", "Scale factor for dequantization"),
],
outputs=[defs.OpSchema.FormalParameter("output", "tensor(float)", "Dequantized output tensor")],
)
TRT_FP8DequantizeLinear = onnxscript.values.Op(
opset=trt_opset, name="TRT_FP8DequantizeLinear", op_schema=schema
)
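As a quick sanity check of the two custom ops above, a hedged sketch that round-trips a tensor through the eager quantize/dequantize path. It assumes a CUDA device, that importing transformer_engine.pytorch registers the tex ops, and that the values fit the FP8 E4M3 range for the chosen scale.

import torch
import transformer_engine.pytorch  # noqa: F401  (assumed to register the tex::fp8_* custom ops)

x = torch.randn(16, 32, device="cuda")
scale = 1.0
q = torch.ops.tex.fp8_quantize(x, scale)       # raw FP8 payload stored as uint8
x_dq = torch.ops.tex.fp8_dequantize(q, scale)  # float32 approximation of x
print((x - x_dq).abs().max())                  # small quantization error expected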
# ONNX MXFP8 Quantization
@torch.library.custom_op("tex::mxfp8_quantize", mutates_args=[])
def onnx_quantize_mxfp8_op(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Quantize to MXFP8Tensor used for inference."""
quantizer = MXFP8Quantizer(tex.DType.kFloat8E4M3)
quantized_tensor = quantizer(tensor)
return quantized_tensor._rowwise_data, quantized_tensor._rowwise_scale_inv
@onnx_quantize_mxfp8_op.register_fake
def _(tensor: torch.Tensor):
"""Fake quantize to MXFP8Tensor used for inference."""
mxfp8_scale_shape = [
round_up_to_nearest_multiple(math.prod(tensor.shape[:-1]), 128),
round_up_to_nearest_multiple(tensor.shape[-1] // MXFP8_BLOCK_SCALING_SIZE, 4),
]
return torch.empty(tensor.shape, dtype=torch.uint8, device=tensor.device), torch.empty(
mxfp8_scale_shape, dtype=torch.uint8, device=tensor.device
)
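A small worked example of the padded scale shape reported by the fake op above, assuming MXFP8_BLOCK_SCALING_SIZE is 32 and a hypothetical (256, 1024) activation:

import math

def round_up(value, multiple):
    return ((value + multiple - 1) // multiple) * multiple

shape = (256, 1024)
rows = round_up(math.prod(shape[:-1]), 128)  # 256
cols = round_up(shape[-1] // 32, 4)          # 1024 // 32 = 32
print(rows, cols)                            # (256, 32): one scale per 32-element block, padded for alignment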
def onnx_quantize_mxfp8_symbolic(
tensor: onnxscript.onnx_types.TensorType,
) -> Tuple[onnxscript.onnx_types.TensorType, onnxscript.onnx_types.TensorType]:
"""Symbolic quantize to MXFP8Tensor used for inference."""
tensor_out, scale_inv_out = TRT_MXFP8QuantizeLinear(tensor)
return tensor_out, scale_inv_out
schema = defs.OpSchema(
name="TRT_MXFP8QuantizeLinear",
domain="trt",
since_version=1,
doc="TRT MXFP8 Quantize Linear used for inference.",
inputs=[
defs.OpSchema.FormalParameter("tensor", "tensor(float)", "Input tensor to quantize"),
],
outputs=[
defs.OpSchema.FormalParameter("output", "tensor(uint8)", "Quantized output tensor"),
defs.OpSchema.FormalParameter(
"scale_inv", "tensor(uint8)", "Scale factor for quantization"
),
],
)
TRT_MXFP8QuantizeLinear = onnxscript.values.Op(
opset=trt_opset, name="TRT_MXFP8QuantizeLinear", op_schema=schema
)
# ONNX MXFP8 Dequantization
@torch.library.custom_op("tex::mxfp8_dequantize", mutates_args=[])
def onnx_dequantize_mxfp8_op(tensor: torch.Tensor, scale_inv: torch.Tensor) -> torch.Tensor:
"""Dequantize from MXFP8Tensor used for inference."""
quantizer = MXFP8Quantizer(tex.DType.kFloat8E4M3)
quantizer_tensor = quantizer.create_tensor_from_data(
tensor, scale_inv, fake_dtype=torch.float32
)
return quantizer_tensor.dequantize()
@onnx_dequantize_mxfp8_op.register_fake
def _(tensor: torch.Tensor, _):
"""Fake dequantize from MXFP8Tensor used for inference."""
return torch.empty(tensor.shape, dtype=torch.float32, device=tensor.device)
def onnx_dequantize_mxfp8_symbolic(
tensor: onnxscript.onnx_types.TensorType, scale_inv: onnxscript.onnx_types.TensorType
) -> onnxscript.onnx_types.TensorType:
"""Symbolic dequantize from MXFP8Tensor used for inference."""
return TRT_MXFP8DequantizeLinear(tensor, scale_inv)
schema = defs.OpSchema(
name="TRT_MXFP8DequantizeLinear",
domain="trt",
since_version=1,
doc="TRT MXFP8 Dequantize Linear from MXFP8Tensor used for inference.",
inputs=[
defs.OpSchema.FormalParameter("tensor", "tensor(uint8)", "Input tensor to dequantize"),
defs.OpSchema.FormalParameter(
"scale_inv", "tensor(uint8)", "Scale factor for dequantization"
),
],
outputs=[defs.OpSchema.FormalParameter("output", "tensor(float)", "Dequantized output tensor")],
)
TRT_MXFP8DequantizeLinear = onnxscript.values.Op(
opset=trt_opset, name="TRT_MXFP8DequantizeLinear", op_schema=schema
)
# ONNX LayerNorm
@torch.library.custom_op("tex::layernorm", mutates_args=[])
def onnx_layernorm_op(
inp: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float
) -> torch.Tensor:
"""ONNX LayerNorm used for inference."""
model = tex.LayerNorm(inp.shape[1], eps=eps)
model.weight.data = weight
model.bias.data = bias
return model(inp)
@onnx_layernorm_op.register_fake
def _(inp, *_):
"""Fake ONNX LayerNorm used for inference."""
return inp
def onnx_layernorm_symbolic(
inp: onnxscript.onnx_types.TensorType,
weight: onnxscript.onnx_types.TensorType,
bias: onnxscript.onnx_types.TensorType,
eps: float,
) -> onnxscript.onnx_types.TensorType:
"""Symbolic ONNX LayerNorm used for inference."""
return op.LayerNormalization(inp, weight, bias, epsilon=eps)
# onnx layernorm helper function - handles layernorm with quantization
def onnx_layernorm(
inp: torch.Tensor,
layer_norm_weight: torch.Tensor,
layer_norm_bias: torch.Tensor,
eps: float,
normalization: str,
zero_centered_gamma: bool,
output_dtype: torch.dtype,
return_layernorm_output: bool,
input_quantizer,
) -> torch.Tensor:
"""ONNX LayerNorm used for inference."""
ln_weight = layer_norm_weight if not zero_centered_gamma else layer_norm_weight + 1
ln_weight = ln_weight.to(inp.dtype).to(torch.float32)
inp = inp.to(torch.float32)
layer_norm_bias = (
layer_norm_bias.to(output_dtype).to(torch.float32) if layer_norm_bias is not None else None
)
if normalization == "RMSNorm":
ln_out = torch.nn.functional.rms_norm(inp, inp.shape[-1:], ln_weight, eps)
else:
ln_out = torch.nn.functional.layer_norm(
inp, inp.shape[-1:], ln_weight, layer_norm_bias, eps
)
ln_out_return = ln_out
if input_quantizer is not None:
if return_layernorm_output:
# In case of return_layernorm_output, layernorm is not fused with fp8 cast,
# so we cast to input_dtype and then perform cast to fp8 if needed
ln_out = ln_out.to(output_dtype).to(torch.float32)
ln_out_return = ln_out
elif isinstance(input_quantizer, MXFP8Quantizer):
# layernorm + mxfp8 quantizer behaves differently
ln_out = ln_out.to(output_dtype).to(torch.float32)
ln_out_quantized = input_quantizer.onnx_quantize(ln_out)
ln_out = input_quantizer.onnx_dequantize(ln_out_quantized)
ln_out = ln_out.to(output_dtype)
return ln_out, ln_out_return
# utility functions
def onnx_attention_mask_func(
attention_scores: torch.Tensor, attention_mask: torch.Tensor
) -> torch.Tensor:
"""Get attention mask without inp"""
assert is_in_onnx_export_mode()
return attention_scores.masked_fill(attention_mask, -10000.0)
# This translation table should be passed to torch.onnx.export function
# using the custom_translation_table=te_translation_table option.
te_translation_table = {
torch.ops.tex.gemm_inf.default: onnx_gemm_inf_symbolic,
torch.ops.tex.fp8_quantize.default: onnx_quantize_fp8_symbolic,
torch.ops.tex.fp8_dequantize.default: onnx_dequantize_fp8_symbolic,
torch.ops.tex.mxfp8_quantize.default: onnx_quantize_mxfp8_symbolic,
torch.ops.tex.mxfp8_dequantize.default: onnx_dequantize_mxfp8_symbolic,
torch.ops.tex.layernorm.default: onnx_layernorm_symbolic,
}
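A hedged sketch of how this table can be wired into an export call, following the comment above. The onnx_extensions module path, the te.onnx_export context manager, and the dynamo=True flag are assumptions about the surrounding API (custom_translation_table requires the dynamo-based exporter in recent PyTorch); the model and shapes are illustrative.

import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch.onnx_extensions import te_translation_table  # module path assumed

model = te.Linear(1024, 1024).cuda().eval()
sample = torch.randn(8, 1024, device="cuda")
model(sample)  # warm-up forward so assert_warmed_up passes (assumption)

with te.onnx_export(True):  # assumed context manager enabling is_in_onnx_export_mode()
    torch.onnx.export(
        model,
        (sample,),
        "te_linear.onnx",
        dynamo=True,
        custom_translation_table=te_translation_table,
    )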
......@@ -5,7 +5,7 @@
"""Helper functions used in fusible operations."""
from __future__ import annotations
from typing import Any, Iterable, Optional
from typing import Optional
import torch
......@@ -13,84 +13,24 @@ from transformer_engine_torch import FP8TensorMeta
from .. import torch_version
from ..fp8 import FP8GlobalStateManager
from ..tensor.float8_tensor import Float8Tensor
from ..utils import (
canonicalize_device,
canonicalize_dtype,
devices_match,
)
def is_float8_tensor(tensor: Any) -> bool:
"""Check if object is a `Float8Tensor`"""
return isinstance(tensor, Float8Tensor)
def convert_tensor(
tensor: torch.Tensor | Float8Tensor,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
memory_format: torch.memory_format = torch.preserve_format,
) -> torch.Tensor | Float8Tensor:
"""Convert tensor attributes, keeping same data if possible"""
# Default kwargs
if device is None:
device = tensor.device
device = canonicalize_device(device)
if dtype is None:
dtype = tensor.dtype
dtype = canonicalize_dtype(dtype)
# Make sure output is detached from autograd graph
tensor = tensor.detach()
# Return immediately if tensor already has desired attributes
if devices_match(device, tensor.device) and dtype == tensor.dtype:
if memory_format == torch.preserve_format or tensor.is_contiguous(
memory_format=memory_format
):
return tensor
from ..tensor.quantized_tensor import QuantizedTensorBase
from ..utils import canonicalize_dtype
# Convert FP8 tensor
if is_float8_tensor(tensor):
data = tensor._data
if not devices_match(device, data.device):
data = data.to(device=device)
if memory_format != torch.preserve_format and not data.is_contiguous(
memory_format=memory_format
):
# Note: torch.Tensor.to ignores memory_format kwarg (see
# https://github.com/pytorch/pytorch/issues/132020).
data = data.contiguous(memory_format=memory_format)
out = Float8Tensor.make_like(tensor, dtype=dtype)
out.data = data
return out
# Convert standard PyTorch tensor
tensor = tensor.to(device=device, dtype=dtype)
if memory_format != torch.preserve_format and not tensor.is_contiguous(
memory_format=memory_format
):
# Note: torch.Tensor.to ignores memory_format kwarg (see
# https://github.com/pytorch/pytorch/issues/132020).
tensor = tensor.contiguous(memory_format=memory_format)
return tensor
def is_quantized_tensor(tensor: torch.Tensor | QuantizedTensorBase) -> bool:
"""Check if tensor is a quantized tensor"""
return isinstance(tensor, QuantizedTensorBase)
def reshape(
tensor: torch.Tensor | Float8Tensor,
shape: Iterable[int],
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
) -> torch.Tensor | Float8Tensor:
"""Reshape tensor, keeping same data if possible"""
tensor = convert_tensor(
tensor,
device=device,
dtype=dtype,
memory_format=torch.contiguous_format,
)
return tensor.reshape(*shape)
def maybe_dequantize(
tensor: torch.Tensor | QuantizedTensorBase, dtype: torch.dtype | None = None
) -> torch.Tensor:
"""Dequantize tensor to given dtype or just convert if not a quantized tensor"""
if is_quantized_tensor(tensor):
return tensor.dequantize(dtype=dtype)
if dtype is not None and tensor.dtype != dtype:
return tensor.to(dtype)
return tensor
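A brief sketch of how maybe_dequantize behaves for plain (non-quantized) tensors; quantized inputs instead go through their dequantize method as shown above.

import torch

x = torch.randn(4, 4, dtype=torch.bfloat16)
y = maybe_dequantize(x, torch.float32)  # plain tensor: cast to the requested dtype
z = maybe_dequantize(x)                 # no dtype given: returned unchanged
assert y.dtype == torch.float32 and z is x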
def maybe_autocast_dtype(
......
......@@ -12,11 +12,10 @@ import torch
import transformer_engine_torch as tex
from ...fp8 import FP8GlobalStateManager
from ...tensor import QuantizedTensor
from ...tensor.float8_tensor import Float8CurrentScalingQuantizer
from ...utils import clear_tensor_data, devices_match
from ...tensor.float8_tensor import Float8CurrentScalingQuantizer, Quantizer
from ...utils import clear_tensor_data
from ..op import BasicOperation, OperationContext
from .._common import reshape
from .._common import maybe_dequantize
class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
......@@ -72,8 +71,8 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op: Optional[BasicOperation] = None,
next_op: Optional[BasicOperation] = None,
prev_op_grad_input_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
) -> torch.Tensor:
# Compute dtype
......@@ -86,35 +85,16 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
raise RuntimeError(f"Unsupported dtype ({dtype})")
# Check input tensor
x = input_
if isinstance(x, QuantizedTensor):
x = x.dequantize()
if x.device.type != "cuda":
x = x.cuda()
if x.dtype != dtype:
x = x.to(dtype=dtype)
if not x.is_contiguous():
x = x.contiguous()
x = maybe_dequantize(input_.contiguous(), dtype)
# Check if quantized compute is enabled
quantized_compute_enabled = FP8GlobalStateManager.is_fp8_enabled()
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
quantizer = None
if (
quantized_compute_enabled
and next_op is not None
and next_op.num_quantizers("forward") > 0
):
quantizer = next_op.get_quantizer("forward", 0)
if with_quantized_compute:
quantizer = next_op_input_quantizer
# Launch kernel
y = self._activation_forward_impl(
reshape(x, (-1, x.size(-1))),
quantizer,
)
# Check output tensor
if y.dim() != x.dim():
y = y.reshape(list(x.shape[:-1]) + [-1])
y = self._activation_forward_impl(x, quantizer)
# Quantize input to FP8 before caching if needed
if self.cache_quantized_input:
......@@ -123,10 +103,10 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
x = input_quantizer(x)
# Save state for backward pass
ctx.save_for_backward(x.detach())
ctx.quantized_compute_enabled = quantized_compute_enabled
ctx.save_for_backward(x)
ctx.with_quantized_compute = with_quantized_compute
ctx.dtype = dtype
ctx.prev_op = prev_op
ctx.prev_op_grad_input_quantizer = prev_op_grad_input_quantizer
return y
......@@ -140,44 +120,20 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
(x,) = ctx.saved_tensors
# Check input tensor
if isinstance(x, QuantizedTensor):
x = x.dequantize(dtype=ctx.dtype)
elif x.dtype != ctx.dtype:
x = x.to(dtype=ctx.dtype)
if not x.is_contiguous():
x = x.contiguous()
x = maybe_dequantize(x.contiguous(), ctx.dtype)
# Check grad output tensor
dy = grad_output
if isinstance(dy, QuantizedTensor):
dy = dy.dequantize(dtype=ctx.dtype)
if not devices_match(dy.device, x.device) or dy.dtype != x.dtype:
dy = dy.to(device=x.device, dtype=x.dtype)
if not dy.is_contiguous():
dy = dy.contiguous()
dy = maybe_dequantize(grad_output.contiguous(), x.dtype)
# Check if quantized compute is enabled
quantizer = None
if (
ctx.quantized_compute_enabled
and ctx.prev_op is not None
and ctx.prev_op.num_quantizers("backward") > 0
):
quantizer = ctx.prev_op.get_quantizer("backward", 0)
if ctx.with_quantized_compute:
quantizer = ctx.prev_op_grad_input_quantizer
# Launch kernel
dx = self._activation_backward_impl(
reshape(dy, (-1, dy.size(-1))),
reshape(x, (-1, x.size(-1))),
quantizer,
)
# Check grad input tensor
if dx.size() != x.size():
dx = dx.reshape(x.size())
dx = self._activation_backward_impl(dy, x, quantizer)
# Clear input tensor if possible
if ctx.prev_op is not None:
clear_tensor_data(x)
return dx, ()
......
......@@ -15,6 +15,8 @@ from transformer_engine.pytorch.ops.op import (
OperationContext,
)
from transformer_engine.pytorch.tensor import Quantizer
class AddInPlace(BasicOperation):
"""Add in-place
......@@ -57,8 +59,8 @@ class AddInPlace(BasicOperation):
input_: torch.Tensor,
*,
basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
basic_op_prev_ops: list[Optional[BasicOperation]],
basic_op_next_ops: list[Optional[BasicOperation]],
prev_op_grad_input_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
basic_op_kwargs: list[dict[str, Any]],
) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
output = basic_op_extra_inputs[0][0].detach()
......@@ -76,4 +78,4 @@ class AddInPlace(BasicOperation):
Iterable[Iterable[Optional[torch.Tensor]]],
Iterable[Iterable[Optional[torch.Tensor]]],
]:
return grad_output, [], [(grad_output,)]
return grad_output, [()], [(grad_output,)]
......@@ -10,8 +10,9 @@ from typing import Optional
import torch
from ...distributed import gather_along_first_dim
from ...tensor import QuantizedTensor
from .._common import maybe_dequantize
from ..op import BasicOperation, OperationContext
from ...tensor import Quantizer
class AllGather(BasicOperation):
......@@ -39,8 +40,8 @@ class AllGather(BasicOperation):
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op: Optional[BasicOperation] = None,
next_op: Optional[BasicOperation] = None,
prev_op_grad_input_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
) -> torch.Tensor:
out: torch.Tensor
if self.process_group_size == 1:
......@@ -71,10 +72,7 @@ class AllGather(BasicOperation):
input_dims[0] //= self.process_group_size
# Check output gradient tensor
dy = grad_output
if isinstance(dy, QuantizedTensor):
dy = dy.dequantize()
dy = dy.contiguous()
dy = maybe_dequantize(grad_output.contiguous())
# Perform reduce-scatter
dx = torch.empty(input_dims, dtype=dy.dtype, device=dy.device)
......
......@@ -9,8 +9,9 @@ from typing import Optional
import torch
from ...tensor import QuantizedTensor
from .._common import maybe_dequantize
from ..op import BasicOperation, OperationContext
from ...tensor import Quantizer
class AllReduce(BasicOperation):
......@@ -41,8 +42,8 @@ class AllReduce(BasicOperation):
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op: Optional[BasicOperation] = None,
next_op: Optional[BasicOperation] = None,
prev_op_grad_input_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
) -> torch.Tensor:
# Trivial case
......@@ -50,10 +51,7 @@ class AllReduce(BasicOperation):
return input_
# Perform all-reduce
x = input_
if isinstance(x, QuantizedTensor):
x = x.dequantize()
x = x.contiguous()
x = maybe_dequantize(input_.contiguous())
torch.distributed.all_reduce(x, group=self.process_group)
return x
......
......@@ -19,20 +19,21 @@ from ...distributed import (
gather_along_first_dim,
reduce_scatter_along_first_dim,
)
from ...fp8 import FP8GlobalStateManager
from ...fp8 import FP8GlobalStateManager, Recipe
from ...module.base import _2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD
from ...tensor import Quantizer, QuantizedTensor
from ...tensor import Quantizer
from ...tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
from ...tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ...tensor.mxfp8_tensor import MXFP8Quantizer
from ...tensor._internal.float8_tensor_base import Float8TensorBase
from ..op import BasicOperation, OperationContext
from .._common import (
from .._common import maybe_dequantize, is_quantized_tensor
from ...utils import (
canonicalize_device,
canonicalize_dtype,
clear_tensor_data,
devices_match,
)
from ...utils import clear_tensor_data
def _wait_async(handle: Optional[Any]) -> None:
......@@ -271,7 +272,7 @@ class BasicLinear(BasicOperation):
device = canonicalize_device(None)
# Allocate buffer if needed
if isinstance(weight, QuantizedTensor):
if is_quantized_tensor(weight):
weight = torch.empty(
weight.size(),
dtype=weight.dtype,
......@@ -302,8 +303,12 @@ class BasicLinear(BasicOperation):
weight = torch.nn.Parameter(weight)
self.weight = weight
def pre_forward(self, *args, **kwargs) -> None:
super().pre_forward(*args, **kwargs)
def pre_first_forward(
self,
*,
recipe: Optional[Recipe],
) -> None:
super().pre_first_forward(recipe=recipe)
# Initialize weights if needed
weight = self.weight
......@@ -312,20 +317,17 @@ class BasicLinear(BasicOperation):
weight = self.weight
# Configure quantizers
if FP8GlobalStateManager.is_fp8_enabled():
if recipe is not None:
input_quantizer = self.get_quantizer("forward", 0)
weight_quantizer = self.get_quantizer("forward", 1)
grad_output_quantizer = self.get_quantizer("backward", 0)
# Specify required tensor formats
is_grad_enabled = torch.is_grad_enabled()
weight_requires_grad = is_grad_enabled and weight.requires_grad
input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled)
grad_output_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
input_quantizer.internal = True
weight_quantizer.internal = True
grad_output_quantizer.internal = True
# Recipe-specific configuration
recipe = FP8GlobalStateManager.get_fp8_recipe()
if recipe.float8_current_scaling():
if any(
not isinstance(q, Float8CurrentScalingQuantizer)
......@@ -390,7 +392,7 @@ class BasicLinear(BasicOperation):
Bias tensor
device: torch.device, default = default CUDA device
Tensor device
dtype: torch.dtype, default = default dtype
dtype: torch.dtype, default = infer from out or weight
Tensor datatype
out: torch.Tensor, optional
Output tensor
......@@ -437,8 +439,14 @@ class BasicLinear(BasicOperation):
# Check datatype
if dtype is None:
dtype = weight.dtype if out is None else out.dtype
dtype = canonicalize_dtype(dtype)
if out is not None and isinstance(out, torch.Tensor):
dtype = out.dtype
elif weight is not None and isinstance(weight, torch.Tensor):
dtype = weight.dtype
else:
raise ValueError(
"Could not infer dtype from weight nor out and dtype was not provided"
)
if dtype not in (torch.float32, torch.float16, torch.bfloat16):
raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})")
if out is not None and out.dtype != dtype:
......@@ -462,14 +470,12 @@ class BasicLinear(BasicOperation):
quantizer=input_quantizer,
)
else:
if not isinstance(x_local, QuantizedTensor):
if not is_quantized_tensor(x_local):
x_local = input_quantizer(x_local)
x = x_local
else:
if isinstance(x_local, QuantizedTensor):
x_local = x_local.dequantize()
if x_local.dtype != dtype:
x_local = x_local.to(dtype=dtype)
x_local = maybe_dequantize(x_local, dtype)
if with_x_all_gather:
x, x_async = gather_along_first_dim(
x_local,
......@@ -481,16 +487,13 @@ class BasicLinear(BasicOperation):
# Check weight tensor
w = weight
w_is_quantized = isinstance(w, QuantizedTensor)
if with_quantized_compute and not w_is_quantized:
if not with_quantized_compute:
w = maybe_dequantize(w, dtype)
elif with_quantized_compute and not is_quantized_tensor(w):
if weight_quantizer is None:
raise ValueError("Missing quantizer for weight tensor")
weight_quantizer.set_usage(rowwise=True, columnwise=input_requires_grad)
w = weight_quantizer(w)
elif not with_quantized_compute and w_is_quantized:
w = w.dequantize()
if not with_quantized_compute and w.dtype != dtype:
w = w.to(dtype=dtype)
# Check output tensor
y = out
......@@ -499,7 +502,7 @@ class BasicLinear(BasicOperation):
output_quantizer = None
if tensor_parallel_mode == "row":
output_quantizer = None
elif isinstance(y, QuantizedTensor):
elif is_quantized_tensor(y):
if not with_quantized_compute:
raise ValueError("Output tensor is quantized, but quantized compute is not enabled")
if tensor_parallel_mode == "row":
......@@ -564,18 +567,14 @@ class BasicLinear(BasicOperation):
# Prepare weight tensor for backward pass
if input_requires_grad:
if w is not weight and with_quantized_compute and isinstance(w, QuantizedTensor):
if w is not weight and with_quantized_compute and is_quantized_tensor(w):
w.update_usage(rowwise_usage=False, columnwise_usage=True)
else:
w = None
# Prepare input tensor for backward pass
if weight_requires_grad:
if x_local is input:
# PyTorch autograd produces esoteric errors if we
# cache input tensor directly.
x_local = x_local.detach()
if with_quantized_compute and isinstance(x_local, QuantizedTensor):
if with_quantized_compute and is_quantized_tensor(x_local):
if not (isinstance(x_local, Float8TensorBase) and with_x_all_gather):
# FP8 does not support all-gather of transpose data
x_local.update_usage(rowwise_usage=False, columnwise_usage=True)
......@@ -668,9 +667,9 @@ class BasicLinear(BasicOperation):
# Check datatype
if dtype is None:
if weight is not None:
if isinstance(weight, torch.Tensor):
dtype = weight.dtype
else:
elif isinstance(grad_output, torch.Tensor):
dtype = grad_output.dtype
dtype = canonicalize_dtype(dtype)
if dtype not in (torch.float32, torch.float16, torch.bfloat16):
......@@ -696,14 +695,17 @@ class BasicLinear(BasicOperation):
quantizer=grad_output_quantizer,
)
else:
if not isinstance(dy_local, QuantizedTensor):
if not is_quantized_tensor(dy_local):
dy_local = grad_output_quantizer(dy_local)
else:
dy_local.update_usage(
rowwise_usage=input_requires_grad,
columnwise_usage=weight_requires_grad,
)
dy = dy_local
else:
if isinstance(dy_local, QuantizedTensor):
dy_local = dy_local.dequantize()
if dy_local.dtype != dtype:
dy_local = dy_local.to(dtype=dtype)
dy_local = maybe_dequantize(dy_local, dtype)
if with_dy_all_gather:
dy, dy_async = gather_along_first_dim(
dy_local,
......@@ -733,16 +735,14 @@ class BasicLinear(BasicOperation):
quantizer=input_quantizer,
)
else:
if isinstance(x_local, QuantizedTensor):
if is_quantized_tensor(x_local):
x_local.update_usage(columnwise_usage=True)
else:
x_local = input_quantizer(x_local)
x = x_local
else:
if isinstance(x_local, QuantizedTensor):
x_local = x_local.dequantize()
if x_local.dtype != dtype:
x_local = x_local.to(dtype=dtype)
x_local = maybe_dequantize(x_local, dtype)
if with_x_all_gather:
x, x_async = gather_along_first_dim(
x_local,
......@@ -761,9 +761,8 @@ class BasicLinear(BasicOperation):
if weight is None:
raise ValueError("Weight tensor is required to compute input grad")
w = weight
w_is_quantized = isinstance(w, QuantizedTensor)
if with_quantized_compute:
if w_is_quantized:
if is_quantized_tensor(w):
w.update_usage(columnwise_usage=True)
else:
if weight_quantizer is None:
......@@ -771,10 +770,7 @@ class BasicLinear(BasicOperation):
weight_quantizer.set_usage(columnwise=True)
w = weight_quantizer(w)
else:
if w_is_quantized:
w = w.dequantize(dtype=dtype)
elif w.dtype != dtype:
w = w.to(dtype=dtype)
w = maybe_dequantize(w, dtype)
# Synchronize tensor-parallel communication
_wait_async(dy_async)
......@@ -787,7 +783,7 @@ class BasicLinear(BasicOperation):
grad_input_quantizer = None
if tensor_parallel_mode == "column":
grad_input_quantizer = None
elif isinstance(dx, QuantizedTensor):
elif is_quantized_tensor(dx):
if not with_quantized_compute:
raise ValueError(
"Grad input tensor is quantized, but quantized compute is not enabled"
......@@ -898,12 +894,12 @@ class BasicLinear(BasicOperation):
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op: Optional[BasicOperation] = None,
next_op: Optional[BasicOperation] = None,
prev_op_grad_input_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
) -> torch.Tensor:
# Check which grads are required
input_requires_grad = ctx.requires_grad and input_.requires_grad
input_requires_grad = ctx.requires_grad
weight_requires_grad = ctx.requires_grad and self.weight.requires_grad
# FP8 metadata
......@@ -918,11 +914,9 @@ class BasicLinear(BasicOperation):
# Get quantizers
input_quantizer = self.get_quantizer("forward", 0)
weight_quantizer = self.get_quantizer("forward", 1)
if next_op is not None and next_op.num_quantizers("forward") > 0:
output_quantizer = next_op.get_quantizer("forward", 0)
output_quantizer = next_op_input_quantizer
grad_output_quantizer = self.get_quantizer("backward", 0)
if prev_op is not None and prev_op.num_quantizers("backward") > 0:
grad_input_quantizer = prev_op.get_quantizer("backward", 0)
grad_input_quantizer = prev_op_grad_input_quantizer
# Configure quantizers
# Note: We cache the quantized input for backward pass,
......@@ -931,9 +925,10 @@ class BasicLinear(BasicOperation):
weight_quantizer.set_usage(rowwise=True, columnwise=False)
# Get autocast dtype if needed
dtype = None
if torch.is_autocast_enabled():
dtype = torch.get_autocast_dtype("cuda")
else:
dtype = self.weight.dtype
# Linear forward
output, x_local, w = BasicLinear._functional_forward(
......@@ -961,7 +956,6 @@ class BasicLinear(BasicOperation):
ctx.dtype = dtype
ctx.input_requires_grad = input_requires_grad
ctx.weight_requires_grad = weight_requires_grad
ctx.has_prev_op = prev_op is not None
return output
......@@ -978,6 +972,9 @@ class BasicLinear(BasicOperation):
accumulate_into_main_grad = self._accumulate_into_main_grad
grad_weight = None
if ctx.weight_requires_grad and accumulate_into_main_grad:
if hasattr(self.weight, "__fsdp_param__"):
self.weight.main_grad = self.weight.get_main_grad()
if not hasattr(self.weight, "main_grad"):
raise RuntimeError(
"BasicLinear op is configured with "
......@@ -1009,7 +1006,6 @@ class BasicLinear(BasicOperation):
)
# Clear input tensor if possible
if ctx.has_prev_op:
clear_tensor_data(x_local)
if accumulate_into_main_grad:
......
......@@ -9,14 +9,17 @@ from typing import Optional
import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.ops.op import (
BasicOperation,
OperationContext,
)
from .._common import (
from ...utils import (
canonicalize_device,
canonicalize_dtype,
)
from ...fp8 import FP8GlobalStateManager
from ...tensor import Quantizer
class Bias(BasicOperation):
......@@ -111,8 +114,8 @@ class Bias(BasicOperation):
bias = torch.nn.Parameter(bias)
self.bias = bias
def pre_forward(self, *args, **kwargs) -> None:
super().pre_forward(*args, **kwargs)
def pre_first_forward(self, *args, **kwargs) -> None:
super().pre_first_forward(*args, **kwargs)
if self.bias.device.type == "meta":
self.reset_parameters()
......@@ -120,11 +123,25 @@ class Bias(BasicOperation):
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op: Optional[BasicOperation] = None,
next_op: Optional[BasicOperation] = None,
prev_op_grad_input_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
) -> torch.Tensor:
x = input_
b = self.bias.reshape([1] * (x.dim() - 1) + [self.local_size])
b = self.bias.view([1] * (x.dim() - 1) + [self.local_size])
# Check if backward pass is needed
requires_grad = ctx.requires_grad
# Check if previous op quantizes its output's gradient
grad_input_quantizer = None
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
if with_quantized_compute:
grad_input_quantizer = prev_op_grad_input_quantizer
if requires_grad:
ctx.with_quantized_compute = with_quantized_compute
ctx.grad_input_quantizer = grad_input_quantizer
return x + b
def op_backward(
......@@ -134,6 +151,10 @@ class Bias(BasicOperation):
) -> tuple[torch.Tensor, tuple[()]]:
dy = grad_output
if dy.dim() > 1:
quantizer = ctx.grad_input_quantizer
if ctx.with_quantized_compute and quantizer is not None:
db, dy = tex.bgrad_quantize(dy, quantizer)
else:
db = dy.sum(tuple(range(dy.dim() - 1)))
else:
db = dy
......
......@@ -13,6 +13,7 @@ from transformer_engine.pytorch.ops.op import (
BasicOperation,
OperationContext,
)
from ...tensor import Quantizer
class Identity(BasicOperation):
......@@ -22,8 +23,8 @@ class Identity(BasicOperation):
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op: Optional[BasicOperation] = None,
next_op: Optional[BasicOperation] = None,
prev_op_grad_input_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
) -> torch.Tensor:
return input_
......
......@@ -9,8 +9,8 @@ from typing import Optional
import torch
from ...tensor import QuantizedTensor
from ...utils import clear_tensor_data
from .._common import maybe_dequantize
from ..op import BasicOperation, OperationContext
from ...jit import (
l2normalization_fused,
......@@ -19,6 +19,7 @@ from ...jit import (
set_jit_fusion_options,
warmup_jit_l2normalization_all_dtypes,
)
from ...tensor import Quantizer
class L2Normalization(BasicOperation):
......@@ -73,14 +74,11 @@ class L2Normalization(BasicOperation):
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op: Optional[BasicOperation] = None,
next_op: Optional[BasicOperation] = None,
prev_op_grad_input_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
) -> torch.Tensor:
# Use input directly - torch.compile can handle multi-dimensional tensors
x = input_
if isinstance(x, QuantizedTensor):
x = x.dequantize()
x = maybe_dequantize(input_)
# Check if backward pass is needed
requires_grad = ctx.requires_grad
......@@ -98,7 +96,6 @@ class L2Normalization(BasicOperation):
# Save state for backward pass
if requires_grad:
ctx.save_for_backward(x, rsqrt_norm)
ctx.has_prev_op = prev_op is not None
return y
......@@ -111,16 +108,12 @@ class L2Normalization(BasicOperation):
# Saved tensors from forward pass
x, rsqrt_norm = ctx.saved_tensors
dy = grad_output
if isinstance(dy, QuantizedTensor):
dy = dy.dequantize()
dy = maybe_dequantize(grad_output)
# Compute L2 norm backward pass using fused implementation
dx = l2normalization_backward_fused(dy, x, rsqrt_norm, self.eps)
# Clear saved tensors if possible
if ctx.has_prev_op:
clear_tensor_data(x)
clear_tensor_data(rsqrt_norm)
......
......@@ -15,7 +15,6 @@ import torch
from transformer_engine_torch import layernorm_bwd, layernorm_fwd
from ...fp8 import FP8GlobalStateManager
from ...constants import TE_DType
from ...tensor import QuantizedTensor
from ...utils import (
canonicalize_device,
canonicalize_dtype,
......@@ -23,7 +22,9 @@ from ...utils import (
devices_match,
)
from ..op import BasicOperation, OperationContext
from .._common import maybe_autocast_dtype, reshape
from .._common import maybe_autocast_dtype, maybe_dequantize
from ...export import is_in_onnx_export_mode
from ...tensor import Quantizer
class LayerNorm(BasicOperation):
......@@ -167,8 +168,8 @@ class LayerNorm(BasicOperation):
self.weight = weight
self.bias = bias
def pre_forward(self, *args, **kwargs) -> None:
super().pre_forward(*args, **kwargs)
def pre_first_forward(self, *args, **kwargs) -> None:
super().pre_first_forward(*args, **kwargs)
if self.weight.device.type == "meta" or self.bias.device.type == "meta":
self.reset_parameters()
......@@ -176,9 +177,11 @@ class LayerNorm(BasicOperation):
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op: Optional[BasicOperation] = None,
next_op: Optional[BasicOperation] = None,
prev_op_grad_input_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
) -> torch.Tensor:
if is_in_onnx_export_mode():
return self.op_onnx_forward(input_)
# Check tensor dims
weight = self.weight
......@@ -192,31 +195,19 @@ class LayerNorm(BasicOperation):
# Check input tensors
inner_dim = math.prod(weight_dims)
device = weight.device
if device.type != "cuda":
device = canonicalize_device(None)
dtype = maybe_autocast_dtype(default_dtype=weight.dtype)
x = reshape(input_, (-1, inner_dim), device=device, dtype=dtype)
w = reshape(self.weight, (inner_dim,), device=device, dtype=dtype)
b = reshape(self.bias, (inner_dim,), device=device, dtype=dtype)
if isinstance(x, QuantizedTensor):
x = x.dequantize()
if isinstance(w, QuantizedTensor):
w = w.dequantize()
if isinstance(b, QuantizedTensor):
b = b.dequantize()
x = maybe_dequantize(input_.contiguous(), dtype).view((-1, inner_dim))
w = maybe_dequantize(self.weight, dtype).view((inner_dim,))
b = maybe_dequantize(self.bias, dtype).view((inner_dim,))
# Check if backward pass is needed
requires_grad = ctx.requires_grad
# Check if output is quantized
output_quantizer = None
if (
FP8GlobalStateManager.is_fp8_enabled()
and next_op is not None
and next_op.num_quantizers("forward") > 0
):
output_quantizer = next_op.get_quantizer("forward", 0)
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
if with_quantized_compute:
output_quantizer = next_op_input_quantizer
# Compute layer norm
sm_margin = self._sm_margins["forward" if requires_grad else "inference"]
......@@ -235,12 +226,10 @@ class LayerNorm(BasicOperation):
# Save state for backward pass
if requires_grad:
ctx.save_for_backward(x, means, rstdevs)
ctx.device = device
ctx.dtype = dtype
ctx.has_prev_op = prev_op is not None
# Reshape output tensor
out = reshape(y, input_dims)
out = y.view(input_dims)
return out
def op_backward(
......@@ -257,14 +246,9 @@ class LayerNorm(BasicOperation):
inner_dim = math.prod(weight_dims)
# Check input tensors
device = ctx.device
dtype = ctx.dtype
dy = reshape(grad_output, x.size(), device=device, dtype=dtype)
w = reshape(self.weight, (inner_dim,), device=device, dtype=dtype)
if isinstance(w, QuantizedTensor):
w = w.dequantize()
if isinstance(dy, QuantizedTensor):
dy = dy.dequantize()
dy = maybe_dequantize(grad_output.contiguous(), dtype).view(x.size())
w = maybe_dequantize(self.weight, dtype).view((inner_dim,))
# Compute layer norm backward pass
dx, dw, db = layernorm_bwd(
......@@ -278,13 +262,22 @@ class LayerNorm(BasicOperation):
)
# Clear saved tensors if possible
if ctx.has_prev_op:
clear_tensor_data(x)
clear_tensor_data(means)
clear_tensor_data(rstdevs)
# Reshape results
grad_input = reshape(dx, grad_output.size())
grad_weight = reshape(dw, weight_dims)
grad_bias = reshape(db, weight_dims)
grad_input = dx.view(grad_output.size())
grad_weight = dw.view(weight_dims)
grad_bias = db.view(weight_dims)
return grad_input, (grad_weight, grad_bias)
def op_onnx_forward(
self,
input_: torch.Tensor,
) -> torch.Tensor:
"""Every operand in this function has a defined ONNX translation."""
weight = self.weight + 1 if self.zero_centered_gamma else self.weight
return torch.nn.functional.layer_norm(
input_, input_.shape[-1:], weight, self.bias, self.eps
)
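For reference, a minimal standalone sketch (not part of this diff) of the zero-centered-gamma convention used by op_onnx_forward above: the stored weight is an offset from 1, so the exported graph scales the normalized input by (weight + 1). Sizes and eps below are illustrative only.

import torch
import torch.nn.functional as F

hidden, eps = 64, 1e-5
x = torch.randn(8, hidden)
gamma = torch.zeros(hidden)  # zero-centered weight, as stored by the op
beta = torch.zeros(hidden)

# Manual layer norm with the (1 + gamma) scale...
mu = x.mean(-1, keepdim=True)
var = ((x - mu) ** 2).mean(-1, keepdim=True)
ref = (x - mu) / torch.sqrt(var + eps) * (1 + gamma) + beta

# ...matches what op_onnx_forward emits through F.layer_norm with weight + 1
out = F.layer_norm(x, (hidden,), gamma + 1, beta, eps)
torch.testing.assert_close(out, ref)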
......@@ -14,6 +14,7 @@ from transformer_engine.pytorch.ops.op import (
BasicOperation,
OperationContext,
)
from ...tensor import Quantizer
class MakeExtraOutput(BasicOperation):
......@@ -58,8 +59,8 @@ class MakeExtraOutput(BasicOperation):
input_: torch.Tensor,
*,
basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
basic_op_prev_ops: list[Optional[BasicOperation]],
basic_op_next_ops: list[Optional[BasicOperation]],
prev_op_grad_input_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
basic_op_kwargs: list[dict[str, Any]],
) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
return input_, [(input_,)]
......@@ -77,4 +78,4 @@ class MakeExtraOutput(BasicOperation):
]:
grad_input = basic_op_grad_extra_outputs[0][0]
grad_input += grad_output
return grad_input, [], [()]
return grad_input, [()], [()]
......@@ -10,8 +10,9 @@ from typing import Optional
import torch
from ...fp8 import FP8GlobalStateManager
from ...tensor import QuantizedTensor
from .._common import is_quantized_tensor
from ..op import BasicOperation, OperationContext
from ...tensor import Quantizer
class Quantize(BasicOperation):
......@@ -49,8 +50,8 @@ class Quantize(BasicOperation):
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op: Optional[BasicOperation] = None,
next_op: Optional[BasicOperation] = None,
prev_op_grad_input_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
) -> torch.Tensor:
# Check if FP8 is enabled
......@@ -60,7 +61,7 @@ class Quantize(BasicOperation):
# Quantize if needed
out = input_
if quantize_forward and not isinstance(out, QuantizedTensor):
if quantize_forward and not is_quantized_tensor(out):
out = self.get_quantizer("forward", 0)(out)
ctx.quantize_backward = quantize_backward
......@@ -72,6 +73,6 @@ class Quantize(BasicOperation):
grad_output: torch.Tensor,
) -> tuple[torch.Tensor, tuple[()]]:
grad_input = grad_output
if ctx.quantize_backward and not isinstance(grad_input, QuantizedTensor):
if ctx.quantize_backward and not is_quantized_tensor(grad_input):
grad_input = self.get_quantizer("backward", 0)(grad_input)
return grad_input, ()
......@@ -10,8 +10,9 @@ from typing import Optional
import torch
from ...distributed import gather_along_first_dim
from ...tensor import QuantizedTensor
from .._common import maybe_dequantize
from ..op import BasicOperation, OperationContext
from ...tensor import Quantizer
class ReduceScatter(BasicOperation):
......@@ -39,8 +40,8 @@ class ReduceScatter(BasicOperation):
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op: Optional[BasicOperation] = None,
next_op: Optional[BasicOperation] = None,
prev_op_grad_input_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
) -> torch.Tensor:
# Trivial case
......@@ -59,10 +60,7 @@ class ReduceScatter(BasicOperation):
output_dims[0] //= self.process_group_size
# Check input tensor
x = input_
if isinstance(x, QuantizedTensor):
x = x.dequantize()
x = x.contiguous()
x = maybe_dequantize(input_.contiguous())
# Perform reduce-scatter
y = torch.empty(output_dims, dtype=x.dtype, device=x.device)
......
......@@ -14,6 +14,7 @@ from transformer_engine.pytorch.ops.op import (
BasicOperation,
OperationContext,
)
from ...tensor import Quantizer
class Reshape(BasicOperation):
......@@ -37,8 +38,8 @@ class Reshape(BasicOperation):
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op: Optional[BasicOperation] = None,
next_op: Optional[BasicOperation] = None,
prev_op_grad_input_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
) -> torch.Tensor:
ctx.input_shape = input_.size()
return input_.reshape(*self._shape)
......
......@@ -14,7 +14,6 @@ import torch
from transformer_engine_torch import rmsnorm_bwd, rmsnorm_fwd
from ...fp8 import FP8GlobalStateManager
from ...tensor import QuantizedTensor
from ...constants import TE_DType
from ...utils import (
canonicalize_device,
......@@ -23,7 +22,9 @@ from ...utils import (
devices_match,
)
from ..op import BasicOperation, OperationContext
from .._common import maybe_autocast_dtype, reshape
from .._common import maybe_autocast_dtype, maybe_dequantize
from ...export import is_in_onnx_export_mode
from ...tensor import Quantizer
class RMSNorm(BasicOperation):
......@@ -150,8 +151,8 @@ class RMSNorm(BasicOperation):
weight = torch.nn.Parameter(weight)
self.weight = weight
def pre_forward(self, *args, **kwargs) -> None:
super().pre_forward(*args, **kwargs)
def pre_first_forward(self, *args, **kwargs) -> None:
super().pre_first_forward(*args, **kwargs)
if self.weight.device.type == "meta":
self.reset_parameters()
......@@ -159,9 +160,11 @@ class RMSNorm(BasicOperation):
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op: Optional[BasicOperation] = None,
next_op: Optional[BasicOperation] = None,
prev_op_grad_input_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
) -> torch.Tensor:
if is_in_onnx_export_mode():
return self.op_onnx_forward(input_)
# Check tensor dims
weight = self.weight
......@@ -175,28 +178,18 @@ class RMSNorm(BasicOperation):
# Check input tensors
inner_dim = math.prod(weight_dims)
device = weight.device
if device.type != "cuda":
device = canonicalize_device(None)
dtype = maybe_autocast_dtype(default_dtype=weight.dtype)
x = reshape(input_, (-1, inner_dim), device=device, dtype=dtype)
w = reshape(self.weight, (inner_dim,), device=device, dtype=dtype)
if isinstance(x, QuantizedTensor):
x = x.dequantize()
if isinstance(w, QuantizedTensor):
w = w.dequantize()
x = maybe_dequantize(input_.contiguous(), dtype).view((-1, inner_dim))
w = maybe_dequantize(self.weight, dtype).view((inner_dim,))
# Check if backward pass is needed
requires_grad = ctx.requires_grad
# Check if output is quantized
output_quantizer = None
if (
FP8GlobalStateManager.is_fp8_enabled()
and next_op is not None
and next_op.num_quantizers("forward") > 0
):
output_quantizer = next_op.get_quantizer("forward", 0)
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
if with_quantized_compute:
output_quantizer = next_op_input_quantizer
# Compute RMSNorm
sm_margin = self._sm_margins["forward" if requires_grad else "inference"]
......@@ -214,12 +207,10 @@ class RMSNorm(BasicOperation):
# Save state for backward pass
if requires_grad:
ctx.save_for_backward(x, rstdevs)
ctx.device = device
ctx.dtype = dtype
ctx.has_prev_op = prev_op is not None
# Reshape output tensor
out = reshape(y, input_dims)
out = y.view(input_dims)
return out
def op_backward(
......@@ -236,14 +227,9 @@ class RMSNorm(BasicOperation):
inner_dim = math.prod(weight_dims)
# Check input tensors
device = ctx.device
dtype = ctx.dtype
dy = reshape(grad_output, x.size(), device=device, dtype=dtype)
w = reshape(self.weight, (inner_dim,), device=device, dtype=dtype)
if isinstance(w, QuantizedTensor):
w = w.dequantize()
if isinstance(dy, QuantizedTensor):
dy = dy.dequantize()
dy = maybe_dequantize(grad_output.contiguous(), dtype).view(x.size())
w = maybe_dequantize(self.weight, dtype).view((inner_dim,))
# Compute RMSNorm backward pass
dx, dw = rmsnorm_bwd(
......@@ -256,11 +242,18 @@ class RMSNorm(BasicOperation):
)
# Clear saved tensors if possible
if ctx.has_prev_op:
clear_tensor_data(x)
clear_tensor_data(rstdevs)
# Reshape results
grad_input = reshape(dx, grad_output.size())
grad_weight = reshape(dw, weight_dims)
grad_input = dx.view(grad_output.size())
grad_weight = dw.view(weight_dims)
return grad_input, (grad_weight,)
def op_onnx_forward(
self,
input_: torch.Tensor,
) -> torch.Tensor:
"""Every operand in this function has a defined ONNX translation."""
weight = self.weight + 1 if self.zero_centered_gamma else self.weight
return torch.nn.functional.rms_norm(input_, input_.shape[-1:], weight, self.eps)
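A similar standalone sketch (not part of this diff) for the RMSNorm path: op_onnx_forward relies on torch.nn.functional.rms_norm, which assumes a recent PyTorch (2.4 or later); the manual form below is the reference it should reproduce. Shapes and eps are illustrative only.

import torch
import torch.nn.functional as F

hidden, eps = 64, 1e-5
x = torch.randn(8, hidden)
w = torch.randn(hidden)

# RMSNorm: scale by the reciprocal root-mean-square of the last dimension
ref = x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps) * w
out = F.rms_norm(x, (hidden,), w, eps)
torch.testing.assert_close(out, ref)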
......@@ -4,6 +4,10 @@
"""Compound tensor operation supported by the operation fuser."""
from .backward_bias_activation import (
BackwardBiasActivation,
fuse_backward_bias_activation,
)
from .backward_linear_add import (
BackwardLinearAdd,
fuse_backward_linear_add,
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Fused backward dbias + dact + quantize."""
from __future__ import annotations
from typing import Optional
import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.fp8 import Recipe
from transformer_engine.pytorch.ops.basic import Bias
from transformer_engine.pytorch.ops.basic.activation import (
_ActivationOperation,
GELU,
ReLU,
)
from transformer_engine.pytorch.ops.op import (
FusedOperation,
FusibleOperation,
OperationContext,
)
from ...utils import clear_tensor_data
from .._common import maybe_dequantize
_fused_activations = {GELU: tex.dbias_dgelu, ReLU: tex.dbias_drelu}
_fusible_activations = tuple(_fused_activations.keys())
class BackwardBiasActivation(FusedOperation):
"""Fused backward dbias + dact + quantize
Uses the next operation's input quantizer.
"""
def __init__(self, *, bias: Bias, activation: _ActivationOperation):
super().__init__((bias, activation))
self._fused_function = _fused_activations[type(activation)]
def fuser_backward(
self,
basic_op_ctxs: list[OperationContext],
grad_output: torch.Tensor,
*,
basic_op_grad_extra_outputs: list[tuple[torch.Tensor, ...]],
) -> tuple[
torch.Tensor,
list[tuple[Optional[torch.Tensor], ...]],
list[tuple[()]],
]:
# Get basic operation contexts
activation_op_ctx = basic_op_ctxs[0]
bias_op_ctx = basic_op_ctxs[1]
# Saved tensors from forward pass
(act_input,) = activation_op_ctx.saved_tensors
# Check activation input tensor
act_input = maybe_dequantize(act_input.contiguous(), activation_op_ctx.dtype)
# Check grad output tensor
dy = maybe_dequantize(grad_output.contiguous(), act_input.dtype)
# Get previous op quantizer
if not bias_op_ctx.with_quantized_compute:
raise RuntimeError(
"BackwardBiasActivation requires quantized compute, "
"but Bias context has it disabled"
)
quantizer = bias_op_ctx.grad_input_quantizer
if quantizer is None:
raise RuntimeError(
"BackwardBiasActivation requires previous op's grad output quantizer, "
"but Bias context has no quantizer"
)
# Launch kernel
db, dx = self._fused_function(dy, act_input, quantizer)
# Clear activation input tensor
clear_tensor_data(act_input)
return dx, [(), (db,)], [(), ()]
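As a plain-PyTorch reference (not TE code) for what the fused kernel computes: the backward of bias followed by GELU yields dx = dy * gelu'(act_input) and db as the reduction of dx over the leading dimension; the real tex.dbias_dgelu additionally returns dx already quantized with the supplied quantizer. The erf-based GELU below is an assumption about the activation variant.

import torch

x = torch.randn(32, 64, requires_grad=True)  # activation input saved in forward
dy = torch.randn(32, 64)                      # incoming grad_output

y = torch.nn.functional.gelu(x)               # stand-in for the activation forward
(dx,) = torch.autograd.grad(y, x, dy)         # dact: dy * gelu'(x)
db = dx.sum(dim=0)                            # dbias: reduce over the token/batch dim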
def fuse_backward_bias_activation(
ops: list[tuple[FusibleOperation, list[int]]],
recipe: Optional[Recipe],
) -> list[tuple[FusibleOperation, list[int]]]:
"""Fused backward dbias + dact + quantize
Parameters
----------
ops: list of tuples
Backward pass operations and the indices of the corresponding
basic operations.
recipe: Recipe, optional
Quantization recipe in use
Returns
-------
ops: list of tuples
Updated backward pass operations
"""
# Check if recipe supports bias activation fusion
if recipe is None or not (recipe.delayed() or recipe.mxfp8()):
return ops
# Scan through ops, fusing if possible
out = []
window = []
while len(ops) >= 3:
out.extend(window)
# Check if first op is a supported activation
window, ops = ops[:1], ops[1:]
op, _ = window[0]
if not isinstance(op, _fusible_activations):
continue
# Check if second op is bias
op, _ = ops[0]
if not isinstance(op, Bias):
continue
# Check if third op has a grad input quantizer
op, _ = ops[1]
if not op.num_quantizers("backward") > 0:
continue
window.extend(ops[:1])
ops = ops[1:]
# Replace window with fused op
op = BackwardBiasActivation(
activation=window[0][0],
bias=window[1][0],
)
basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window]
window = [(op, basic_op_idxs)]
# Return list of ops
out.extend(window)
out.extend(ops)
return out
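The scanning loop above is terse; the schematic below (plain Python, not TE code, with stand-in predicates replacing the isinstance and quantizer checks) shows the same window-of-three pattern: the (activation, bias) pair collapses into one fused entry while the third op, which only supplies the quantizer, stays in place.

def fuse_schematic(ops, is_act, is_bias, has_bwd_quantizer, make_fused):
    out, window = [], []
    while len(ops) >= 3:
        out.extend(window)
        # Pop one op; it must be a supported activation
        window, ops = ops[:1], ops[1:]
        if not is_act(window[0]):
            continue
        # The next op must be a bias
        if not is_bias(ops[0]):
            continue
        # The op after that must expose a backward quantizer
        if not has_bwd_quantizer(ops[1]):
            continue
        window.extend(ops[:1])
        ops = ops[1:]
        # Replace the (activation, bias) window with a single fused entry
        window = [make_fused(window[0], window[1])]
    out.extend(window)
    out.extend(ops)
    return out

# Example: ["gelu", "bias", "linear", "relu"] -> [("fused", "gelu", "bias"), "linear", "relu"]
print(fuse_schematic(
    ["gelu", "bias", "linear", "relu"],
    is_act=lambda o: o in ("gelu", "relu"),
    is_bias=lambda o: o == "bias",
    has_bwd_quantizer=lambda o: o == "linear",
    make_fused=lambda act, bias: ("fused", act, bias),
))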