Commit 27ddce40 authored by wenjh

Merge branch 'nv_main'

parents d262ef4c 5b3092a0
......@@ -94,39 +94,45 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None):
# bf16 (recipe is None):
return {
"gelu": (tex.gelu, tex.dgelu, None),
"relu": (tex.relu, tex.drelu, None),
"geglu": (tex.geglu, tex.dgeglu, None),
"reglu": (tex.reglu, tex.dreglu, None),
"swiglu": (tex.swiglu, tex.dswiglu, None),
"qgelu": (tex.qgelu, tex.dqgelu, None),
"qgeglu": (tex.qgeglu, tex.dqgeglu, None),
"relu": (tex.relu, tex.drelu, None),
"reglu": (tex.reglu, tex.dreglu, None),
"srelu": (tex.srelu, tex.dsrelu, None),
"sreglu": (tex.sreglu, tex.dsreglu, None),
"silu": (tex.silu, tex.dsilu, None),
"swiglu": (tex.swiglu, tex.dswiglu, None),
}
if recipe.delayed() or recipe.mxfp8():
# Delayed scaling, fusion supported list: [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu, tex.dbias_dsilu]
# MXFP8: [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu, tex.dbias_dsilu]
return {
"gelu": (tex.gelu, tex.dgelu, tex.dbias_dgelu),
"relu": (tex.relu, tex.drelu, tex.dbias_drelu),
"geglu": (tex.geglu, tex.dgeglu, None),
"reglu": (tex.reglu, tex.dreglu, None),
"swiglu": (tex.swiglu, tex.dswiglu, None),
"qgelu": (tex.qgelu, tex.dqgelu, tex.dbias_dqgelu),
"qgeglu": (tex.qgeglu, tex.dqgeglu, None),
"relu": (tex.relu, tex.drelu, tex.dbias_drelu),
"reglu": (tex.reglu, tex.dreglu, None),
"srelu": (tex.srelu, tex.dsrelu, tex.dbias_dsrelu),
"sreglu": (tex.sreglu, tex.dsreglu, None),
"silu": (tex.silu, tex.dsilu, tex.dbias_dsilu),
"swiglu": (tex.swiglu, tex.dswiglu, None),
}
# Per-tensor current scaling or FP8 blockwise scaling: no dbias+dact fusions implemented yet, fusion supported list: []
if recipe.float8_current_scaling() or recipe.float8_block_scaling():
return {
"gelu": (tex.gelu, tex.dgelu, None),
"relu": (tex.relu, tex.drelu, None),
"geglu": (tex.geglu, tex.dgeglu, None),
"reglu": (tex.reglu, tex.dreglu, None),
"swiglu": (tex.swiglu, tex.dswiglu, None),
"qgelu": (tex.qgelu, tex.dqgelu, None),
"qgeglu": (tex.qgeglu, tex.dqgeglu, None),
"relu": (tex.relu, tex.drelu, None),
"reglu": (tex.reglu, tex.dreglu, None),
"srelu": (tex.srelu, tex.dsrelu, None),
"sreglu": (tex.sreglu, tex.dsreglu, None),
"silu": (tex.silu, tex.dsilu, None),
"swiglu": (tex.swiglu, tex.dswiglu, None),
}
raise NotImplementedError(f"Unhandled recipe type {recipe}")
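For context, a minimal sketch of how a caller might consume the (forward, backward, fused dbias+backward) triples returned above; this is hypothetical usage, and the tex.* kernel signatures shown are assumptions rather than the actual API:
# Hypothetical consumer of _get_act_func_supported_list (signatures assumed).
act_fwd, act_bwd, dbias_dact = _get_act_func_supported_list(recipe)["gelu"]
act_out = act_fwd(fc1_out, fc2_input_quantizer)  # forward activation
if dbias_dact is not None:
    # Fused path: bias gradient and activation gradient in one kernel.
    fc1_bias_grad, dact = dbias_dact(grad_out, fc1_out, grad_quantizer)
else:
    # Unfused fallback: separate activation gradient and bias reduction.
    dact = act_bwd(grad_out, fc1_out, grad_quantizer)
    fc1_bias_grad = dact.sum(dim=0)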
......@@ -308,7 +314,7 @@ class _LayerNormMLP(torch.autograd.Function):
fc1_input_quantizer.set_usage(rowwise=True, columnwise=False)
if ub_overlap_ag:
# Copy into Userbuffers buffer
ub_obj_lnout = get_ub("fc1_fprop")
ub_obj_lnout = get_ub("fc1_fprop", fp8)
ln_out_total, _ = fill_userbuffers_buffer_for_all_gather(
ub_obj_lnout,
ln_out,
......@@ -446,20 +452,25 @@ class _LayerNormMLP(torch.autograd.Function):
act_out = activation_func(fc1_out, None)
act_out = tex.quantize(act_out, fc2_input_quantizer)
else:
act_out = activation_func(fc1_out, fc2_input_quantizer)
if fp8_calibration:
act_out = activation_func(fc1_out, None)
else:
act_out = activation_func(fc1_out, fc2_input_quantizer)
if not is_grad_enabled:
clear_tensor_data(fc1_out)
if fp8_calibration:
fc2_input_quantizer.calibrate(act_out)
fc2_weight_quantizer.calibrate(fc2_weight)
if not fp8 and fp8_calibration:
if fc2_input_quantizer is not None:
fc2_input_quantizer.calibrate(act_out)
if fc2_weight_quantizer is not None:
fc2_weight_quantizer.calibrate(fc2_weight)
# Configure Userbuffers reduce-scatter if needed
ub_obj_fc2out = None
reduce_scatter_out = None
if ub_overlap_rs:
ub_obj_fc2out = get_ub("fc2_fprop")
ub_obj_fc2out = get_ub("fc2_fprop", fp8)
dim_size = list(act_out.size())
dim_size[0] //= tp_world_size
dim_size[-1] = fc2_weight.size(0)
......@@ -741,7 +752,7 @@ class _LayerNormMLP(torch.autograd.Function):
# Note: Cast to expected dtype and perform tensor-parallel communication
ub_obj_fc2_dgrad = None
if ctx.ub_overlap_ag:
ub_obj_fc2_dgrad = get_ub("fc2_dgrad")
ub_obj_fc2_dgrad = get_ub("fc2_dgrad", ctx.fp8)
ctx.ub_obj_gradout = ub_obj_fc2_dgrad
(
grad_output,
......@@ -765,7 +776,7 @@ class _LayerNormMLP(torch.autograd.Function):
# wgrad GEMM requires input with column-wise usage
quantizer.set_usage(rowwise=False, columnwise=True)
if ctx.ub_bulk_dgrad:
ub_obj_fc1_dgrad = get_ub("fc1_dgrad")
ub_obj_fc1_dgrad = get_ub("fc1_dgrad", ctx.fp8)
ln_out_total, _ = fill_userbuffers_buffer_for_all_gather(
ub_obj_fc1_dgrad,
ln_out,
......@@ -870,7 +881,7 @@ class _LayerNormMLP(torch.autograd.Function):
ub_obj_fc2_dgrad.get_communication_stream()
)
ub_obj_fc2_wgrad = get_ub("fc2_wgrad")
ub_obj_fc2_wgrad = get_ub("fc2_wgrad", ctx.fp8)
ctx.fc2_grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
......@@ -1045,16 +1056,16 @@ class _LayerNormMLP(torch.autograd.Function):
fc1_dgrad_shape = [reduce(multiply_op, inputmat.shape[:-1]), inputmat.shape[-1]]
if ctx.ub_overlap_rs_dgrad:
# Overlap DGRAD+RS
ub_obj_fc1_dgrad = get_ub("fc1_dgrad")
ub_obj_fc1_dgrad = get_ub("fc1_dgrad", ctx.fp8)
ub_type_fc1_dgrad = tex.CommOverlapType.RS
else:
if ctx.ub_bulk_dgrad:
# Overlap ln_out all-gather with DGRAD compute
ub_obj_fc1_dgrad = get_ub("fc1_dgrad")
ub_obj_fc1_dgrad = get_ub("fc1_dgrad", ctx.fp8)
ub_type_fc1_dgrad = tex.CommOverlapType.AG
if ctx.ub_bulk_wgrad:
# Overlap FC1 DGRAD reduce-scatter with WGRAD compute
ub_obj_fc1_wgrad = get_ub("fc1_wgrad")
ub_obj_fc1_wgrad = get_ub("fc1_wgrad", ctx.fp8)
ub_type_fc1_wgrad = tex.CommOverlapType.RS
# --------------------------------------------------
......@@ -1402,7 +1413,7 @@ class _LayerNormMLP(torch.autograd.Function):
class LayerNormMLP(TransformerEngineBaseModule):
r"""
Applies layer normalization on the input followed by the MLP module, consisting of
2 successive linear transformations, separated by the GeLU activation.
2 successive linear transformations, separated by the activation function.
Parameters
----------
......@@ -1418,7 +1429,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
type of normalization applied.
activation : str, default = 'gelu'
activation function used.
Options: 'gelu', 'geglu', 'relu', 'reglu', 'squared_relu', 'swiglu', 'qgelu', 'srelu'.
Options: 'gelu', 'geglu', 'qgelu', 'qgeglu', 'relu', 'reglu', 'srelu', 'sreglu',
'silu', and 'swiglu'.
init_method : Callable, default = `None`
used for initializing FC1 weights in the following way: `init_method(weight)`.
When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
......@@ -1559,7 +1571,11 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.gemm_gelu_fusion = (
bool(int(os.getenv("NVTE_GEMM_GELU_FUSION", "0")))
and self.activation == "gelu"
and ((_ub_communicators is None) or (not get_ub("fc1_fprop").is_atomic_gemm()))
and all(
("fc1_fprop", use_fp8) not in _ub_communicators
or not get_ub("fc1_fprop", use_fp8).is_atomic_gemm()
for use_fp8 in [False, True]
)
)
self.name = name
......@@ -1619,7 +1635,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.layer_norm_bias = None
# FC1 init
if self.activation in ["reglu", "geglu", "qgeglu", "swiglu"]:
if self.activation in ["geglu", "qgeglu", "reglu", "sreglu", "swiglu"]:
fc1_output_features = 2 * self.size_per_partition
else:
fc1_output_features = self.size_per_partition
......@@ -1777,7 +1793,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
fp8_output = False
if self.ub_overlap_rs:
if get_ub("fc2_fprop").is_fp8_ubuf():
if get_ub("fc2_fprop", FP8GlobalStateManager.is_fp8_enabled()).is_fp8_ubuf():
fp8_output = True
with torch.cuda.device(
......@@ -1915,7 +1931,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
fc2_grad_output_quantizer,
) = [None] * 10
fc1_weight_quantizer, fc2_weight_quantizer = self._get_weight_quantizers()
if self.fp8:
if self.fp8 or self.fp8_calibration:
fc1_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
fc1_input_quantizer.internal = True
fc2_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT]
......@@ -2001,14 +2017,17 @@ class LayerNormMLP(TransformerEngineBaseModule):
activation_map = {
"gelu": lambda x: torch.nn.functional.gelu(x, approximate="tanh"),
"relu": torch.nn.functional.relu,
"geglu": lambda x: torch.nn.functional.gelu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1],
"reglu": lambda x: torch.nn.functional.relu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1],
"swiglu": lambda x: torch.nn.functional.silu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1],
"qgelu": lambda x: torch.nn.functional.gelu(x, approximate="tanh"),
"qgeglu": lambda x: torch.nn.functional.gelu(x.chunk(2, -1)[0], approximate="tanh")
* x.chunk(2, -1)[1],
"qgelu": lambda x: torch.nn.functional.gelu(x, approximate="tanh"),
"srelu": torch.nn.functional.softplus,
"relu": torch.nn.functional.relu,
"reglu": lambda x: torch.nn.functional.relu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1],
"srelu": lambda x: torch.nn.functional.relu(x) ** 2,
"sreglu": lambda x: torch.nn.functional.relu(x.chunk(2, -1)[0]) ** 2
* x.chunk(2, -1)[1],
"silu": torch.nn.functional.silu,
"swiglu": lambda x: torch.nn.functional.silu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1],
}
if self.activation not in activation_map:
raise ValueError(f"Unsupported activation in onnx export: {self.activation}")
......@@ -2129,7 +2148,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
def _get_weight_quantizers(self) -> List[Quantizer]:
"""Get the weight quantizers of the module."""
if not self.fp8:
if not self.fp8 and not self.fp8_calibration:
return [None, None]
fc1_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
fc1_weight_quantizer.internal = True
......@@ -2182,10 +2201,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
if self.fc1_bias.grad is None:
self.fc1_bias.grad = fc1_bias_grad.to(self.fc1_bias.dtype)
if not self.fuse_wgrad_accumulation:
if self.fc2_weight.grad is None:
self.fc2_weight.grad = fc2_wgrad.to(self.fc2_weight.dtype)
if self.fc1_weight.grad is None:
self.fc1_weight.grad = fc1_wgrad.to(self.fc1_weight.dtype)
self.fc2_weight.grad = fc2_wgrad.to(self.fc2_weight.dtype)
self.fc1_weight.grad = fc1_wgrad.to(self.fc1_weight.dtype)
del fc2_bias_grad_
del fc2_wgrad
del fc1_wgrad
......
......@@ -147,10 +147,10 @@ class _Linear(torch.autograd.Function):
ub_obj = None
ub_type = None
if ub_overlap_rs_fprop:
ub_obj = get_ub(ub_name + "_fprop")
ub_obj = get_ub(ub_name + "_fprop", fp8)
ub_type = tex.CommOverlapType.RS
elif ub_overlap_ag_fprop:
ub_obj = get_ub(ub_name + "_fprop")
ub_obj = get_ub(ub_name + "_fprop", fp8)
ub_type = tex.CommOverlapType.AG
# ------------------------------------------------------
......@@ -319,6 +319,13 @@ class _Linear(torch.autograd.Function):
# Finished forward GEMM...
# ------------------------------------------------------
# Deallocate GEMM input tensor if no longer needed
# TODO(yuzhongw, tmoon): Figure out why inputmat_total is not automatically
# deallocated by GC. Manually deallocating is a temporary hack.
if with_input_all_gather_nccl:
clear_tensor_data(inputmat_total)
inputmat_total = None
# ------------------------------------------------------
# Prepare output tensor
# Note: Perform tensor-parallel communication
......@@ -544,23 +551,23 @@ class _Linear(torch.autograd.Function):
dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]]
if ctx.ub_overlap_ag:
# Overlap grad_output all-gather with dgrad compute
ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8)
ub_obj_dgrad = ctx.ub_obj_gradout
ub_type_dgrad = tex.CommOverlapType.AG
elif ctx.ub_overlap_rs_dgrad:
# Overlap dgrad reduce-scatter with dgrad compute
ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8)
ub_obj_dgrad = ctx.ub_obj_gradout
ub_type_dgrad = tex.CommOverlapType.RS
else:
if ctx.ub_bulk_dgrad:
# Overlap inputmat all-gather with dgrad compute
ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8)
ub_obj_dgrad = ctx.ub_obj_gradout
ub_type_dgrad = tex.CommOverlapType.AG
if ctx.ub_bulk_wgrad:
# Overlap dgrad reduce-scatter with wgrad compute
ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad")
ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8)
ub_type_wgrad = tex.CommOverlapType.RS
# --------------------------------------------------
......@@ -793,7 +800,7 @@ class _Linear(torch.autograd.Function):
dgrad_send_stream, dgrad_recv_stream = ub_obj_dgrad.get_communication_stream()
# This object is separate from the ub_obj_wgrad object which is passed to the GEMM
ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad")
ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8)
ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
......@@ -905,9 +912,16 @@ class _Linear(torch.autograd.Function):
grad_bias = grad_bias_
del grad_bias_
# Deallocate input tensor if permitted
# Deallocate tensors if permitted
if ctx.owns_input:
# Input tensor is internal
clear_tensor_data(inputmat_total)
elif ctx.backward_input_needs_gather:
# Gathered input tensor is internal
clear_tensor_data(inputmat_total)
if ctx.parallel_mode == "row" and ctx.sequence_parallel:
# Gathered grad output tensor is internal
clear_tensor_data(grad_output)
# Update grad input if overlapping reduce-scatter with wgrad GEMM
if ctx.ub_bulk_wgrad:
......@@ -1404,10 +1418,14 @@ class Linear(TransformerEngineBaseModule):
is_first_microbatch = False
if self.ub_overlap_rs_fprop:
if get_ub(self.ub_name + "_fprop").is_fp8_ubuf():
if get_ub(
self.ub_name + "_fprop", FP8GlobalStateManager.is_fp8_enabled()
).is_fp8_ubuf():
fp8_output = True
if self.ub_overlap_rs_dgrad:
if get_ub(self.ub_name + "_dgrad").is_fp8_ubuf():
if get_ub(
self.ub_name + "_dgrad", FP8GlobalStateManager.is_fp8_enabled()
).is_fp8_ubuf():
fp8_grad = True
with torch.cuda.device(
......@@ -1666,7 +1684,7 @@ class Linear(TransformerEngineBaseModule):
def _get_weight_quantizers(self) -> List[Quantizer]:
"""Get the weight quantizers of the module."""
if not self.fp8:
if not self.fp8 and not self.fp8_calibration:
return [None]
weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
weight_quantizer.internal = True
......
......@@ -112,7 +112,9 @@ schema = defs.OpSchema(
doc="TRT FP8 Quantize Linear used for inference.",
inputs=[
defs.OpSchema.FormalParameter("tensor", "tensor(float)", "Input tensor to quantize"),
defs.OpSchema.FormalParameter("scale", "tensor(float)", "Scale factor for quantization"),
defs.OpSchema.FormalParameter(
"scale_inv", "tensor(float)", "Inverse scale factor for quantization"
),
],
outputs=[defs.OpSchema.FormalParameter("output", "tensor(uint8)", "Quantized output tensor")],
)
......@@ -126,11 +128,10 @@ TRT_FP8QuantizeLinear = onnxscript.values.Op(
@torch.library.custom_op("tex::fp8_dequantize", mutates_args=[])
def onnx_dequantize_fp8_op(tensor: torch.Tensor, scale: float) -> torch.Tensor:
def onnx_dequantize_fp8_op(tensor: torch.Tensor, scale_inv: torch.Tensor) -> torch.Tensor:
"""Dequantize from Float8Tensor used for inference."""
scale_tensor = torch.tensor(scale, dtype=torch.float32, device=tensor.device)
quantizer = Float8Quantizer(
scale_tensor, torch.zeros(1).to(tensor.device), tex.DType.kFloat8E4M3
1 / scale_inv, torch.zeros(1).to(tensor.device), tex.DType.kFloat8E4M3
)
quantizer_tensor = quantizer.create_tensor_from_data(tensor, fake_dtype=torch.float32)
return quantizer_tensor.dequantize()
......@@ -143,10 +144,9 @@ def _(tensor: torch.Tensor, _) -> torch.Tensor:
def onnx_dequantize_fp8_symbolic(
tensor: onnxscript.onnx_types.TensorType, scale: float
tensor: onnxscript.onnx_types.TensorType, scale_inv: onnxscript.onnx_types.TensorType
) -> onnxscript.onnx_types.TensorType:
"""Symbolic dequantize from Float8Tensor used for inference."""
scale_inv = op.Constant(value_float=1 / scale)
return TRT_FP8DequantizeLinear(tensor, scale_inv)
......@@ -157,7 +157,9 @@ schema = defs.OpSchema(
doc="TRT FP8 Dequantize Linear from Float8Tensor used for inference.",
inputs=[
defs.OpSchema.FormalParameter("tensor", "tensor(uint8)", "Input tensor to dequantize"),
defs.OpSchema.FormalParameter("scale", "tensor(float)", "Scale factor for dequantization"),
defs.OpSchema.FormalParameter(
"scale_inv", "tensor(float)", "Inverse scale factor for dequantization"
),
],
outputs=[defs.OpSchema.FormalParameter("output", "tensor(float)", "Dequantized output tensor")],
)
......@@ -166,6 +168,43 @@ TRT_FP8DequantizeLinear = onnxscript.values.Op(
opset=trt_opset, name="TRT_FP8DequantizeLinear", op_schema=schema
)
# ONNX FP8 Current Scaling Quantization
@torch.library.custom_op("tex::fp8_cs_quantize", mutates_args=[])
def onnx_cs_quantize_fp8_op(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Quantize to FP8 with current scaling; returns (uint8, scale_inv)."""
if tensor.dtype != torch.float32:
tensor = tensor.to(torch.float32)
amax = tensor.abs().max()
eps = torch.tensor(1e-12, dtype=torch.float32, device=tensor.device)
amax = torch.maximum(amax, eps)
fp8_max = torch.tensor(448, dtype=torch.float32, device=tensor.device)
scale = fp8_max / amax
q = torch.ops.tex.fp8_quantize(tensor, scale)
scale_inv = 1 / scale
return q, scale_inv
@onnx_cs_quantize_fp8_op.register_fake
def _(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
return torch.empty(tensor.shape, dtype=torch.uint8, device=tensor.device), torch.ones(
1, dtype=torch.float32, device=tensor.device
)
def onnx_quantize_fp8_cs_symbolic(
tensor: onnxscript.onnx_types.TensorType,
):
"""Symbolic quantize with current scaling; computes scale_inv from tensor."""
# scale_inv = max(abs(tensor)) / 448, i.e. the reciprocal of scale = 448 / amax
amax = op.ReduceMax(op.Abs(tensor), keepdims=0)
eps = op.Constant(value_float=1.0e-12)
amax = op.Max(amax, eps)
scale_inv = op.Div(amax, op.Constant(value_float=448.0))
q = TRT_FP8QuantizeLinear(tensor, scale_inv)
return q, scale_inv
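A worked example of the scale bookkeeping shared by the eager op and the symbolic graph above (assuming the E4M3 maximum of 448 hard-coded there): quantization multiplies by scale = 448 / amax, while the exported graph only materializes scale_inv = amax / 448, its exact reciprocal:
import torch
t = torch.tensor([0.5, -2.0, 7.0])
amax = torch.maximum(t.abs().max(), torch.tensor(1e-12))
scale = 448.0 / amax      # applied before casting to FP8
scale_inv = amax / 448.0  # carried alongside the uint8 payload
assert torch.allclose(scale * scale_inv, torch.tensor(1.0))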
# ONNX MXFP8 Quantization
......@@ -194,12 +233,12 @@ def onnx_quantize_mxfp8_symbolic(
tensor: onnxscript.onnx_types.TensorType,
) -> Tuple[onnxscript.onnx_types.TensorType, onnxscript.onnx_types.TensorType]:
"""Symbolic quantize to MXFP8Tensor used for inference."""
tensor_out, scale_inv_out = TRT_MXFP8QuantizeLinear(tensor)
tensor_out, scale_inv_out = TRT_MXFP8DynamicQuantize(tensor)
return tensor_out, scale_inv_out
schema = defs.OpSchema(
name="TRT_MXFP8QuantizeLinear",
name="TRT_MXFP8DynamicQuantize",
domain="trt",
since_version=1,
doc="TRT MXFP8 Quantize Linear used for inference.",
......@@ -214,8 +253,8 @@ schema = defs.OpSchema(
],
)
TRT_MXFP8QuantizeLinear = onnxscript.values.Op(
opset=trt_opset, name="TRT_MXFP8QuantizeLinear", op_schema=schema
TRT_MXFP8DynamicQuantize = onnxscript.values.Op(
opset=trt_opset, name="TRT_MXFP8DynamicQuantize", op_schema=schema
)
......@@ -356,6 +395,7 @@ te_translation_table = {
torch.ops.tex.gemm_inf.default: onnx_gemm_inf_symbolic,
torch.ops.tex.fp8_quantize.default: onnx_quantize_fp8_symbolic,
torch.ops.tex.fp8_dequantize.default: onnx_dequantize_fp8_symbolic,
torch.ops.tex.fp8_cs_quantize.default: onnx_quantize_fp8_cs_symbolic,
torch.ops.tex.mxfp8_quantize.default: onnx_quantize_mxfp8_symbolic,
torch.ops.tex.mxfp8_dequantize.default: onnx_dequantize_mxfp8_symbolic,
torch.ops.tex.layernorm.default: onnx_layernorm_symbolic,
......
......@@ -29,7 +29,9 @@ def maybe_dequantize(
if is_quantized_tensor(tensor):
return tensor.dequantize(dtype=dtype)
if dtype is not None and tensor.dtype != dtype:
return tensor.to(dtype)
tensor = tensor.to(dtype)
if not tensor.is_contiguous():
tensor = tensor.contiguous()
return tensor
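A quick illustration (plain PyTorch) of the guarantee this change adds: callers of maybe_dequantize now always receive a contiguous tensor, even when the input was a non-contiguous view:
import torch
x = torch.randn(4, 4).t()  # non-contiguous view
y = x.to(torch.bfloat16)   # dtype cast preserves the strides
if not y.is_contiguous():
    y = y.contiguous()
assert y.is_contiguous()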
......
......@@ -4,7 +4,7 @@
"""Single tensor operations supported by the operation fuser."""
from .activation import GELU, ReLU, GEGLU, ReGLU, SwiGLU
from .activation import GELU, GEGLU, QGELU, QGEGLU, ReLU, ReGLU, SReLU, SReGLU, SiLU, SwiGLU
from .add_extra_input import AddExtraInput
from .all_gather import AllGather
from .all_reduce import AllReduce
......
......@@ -11,11 +11,25 @@ from typing import Optional
import torch
import transformer_engine_torch as tex
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...tensor.float8_tensor import Float8CurrentScalingQuantizer, Quantizer
from ...utils import clear_tensor_data
from ..op import BasicOperation, OperationContext
from .._common import maybe_dequantize
__all__ = [
"GELU",
"GEGLU",
"QGELU",
"QGEGLU",
"ReLU",
"ReGLU",
"SReLU",
"SReGLU",
"SiLU",
"SwiGLU",
]
class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
r"""Apply activation function
......@@ -97,6 +111,8 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
# Save state for backward pass
if ctx.requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(x)
ctx.save_for_backward(x)
ctx.dtype = dtype
ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer
......@@ -147,37 +163,75 @@ class GELU(_ActivationOperation):
return tex.dgelu(*args, **kwargs)
class ReLU(_ActivationOperation):
r"""Rectified linear unit
class GEGLU(_ActivationOperation):
r"""Gaussian Error Gated Linear Unit
The input tensor is split into chunks :math:`a` and :math:`b`
along the last dimension and the following is computed:
.. math::
\text{ReLU}(x) = \max(x,0)
\text{GEGLU}(a,b) = \text{GELU}(a) * b
where
.. math::
\text{GELU}(x) \approx \frac{x}{2} \left( 1 + \tanh\left( 0.797x+0.036 x^3 \right) \right)
.. warning::
Transformer Engine's gated activations and PyTorch's GLU
activation follow opposite conventions for :math:`a` and
:math:`b`. Transformer Engine applies the gating function to
the first half of the input tensor, while PyTorch applies it to
the second half.
See `GLU Variants Improve Transformer <https://arxiv.org/abs/2002.05202>`__.
"""
def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor:
return tex.relu(*args, **kwargs)
return tex.geglu(*args, **kwargs)
def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
return tex.drelu(*args, **kwargs)
return tex.dgeglu(*args, **kwargs)
class GEGLU(_ActivationOperation):
r"""Gaussian error gated linear unit
class QGELU(_ActivationOperation):
r"""Quick Gaussian Error Linear Unit
Quick GELU from `HuggingFace <https://github.com/huggingface/transformers/blob/3e93dd295b5343557a83bc07b0b2ea64c926f9b4/src/transformers/activations.py#L90>`__
and `paper <https://github.com/hendrycks/GELUs>`__.
.. math::
\text{QGELU}(x) \approx x * \sigma(1.702 * x)
"""
def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor:
return tex.qgelu(*args, **kwargs)
def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
return tex.dqgelu(*args, **kwargs)
class QGEGLU(_ActivationOperation):
r"""Quick Gaussian Error Gated Linear Unit
The input tensor is split into chunks :math:`a` and :math:`b`
along the last dimension and the following is computed:
.. math::
\text{GEGLU}(a,b) = \text{GELU}(a) * b
\text{QGEGLU}(a,b) = \text{QGELU}(a) * b
where
.. math::
\text{GELU}(x) \approx \frac{x}{2} \left( 1 + \tanh\left( 0.797x+0.036 x^3 \right) \right)
\text{QGELU}(x) \approx x * \sigma(1.702 * x)
.. warning::
......@@ -187,19 +241,33 @@ class GEGLU(_ActivationOperation):
the first half of the input tensor, while PyTorch applies it to
the second half.
See `GLU Variants Improve Transformer <https://arxiv.org/abs/2002.05202>`__.
"""
def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor:
return tex.qgeglu(*args, **kwargs)
def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
return tex.dqgeglu(*args, **kwargs)
class ReLU(_ActivationOperation):
r"""Rectified Linear Unit
.. math::
\text{ReLU}(x) = \max(x,0)
"""
def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor:
return tex.geglu(*args, **kwargs)
return tex.relu(*args, **kwargs)
def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
return tex.dgeglu(*args, **kwargs)
return tex.drelu(*args, **kwargs)
class ReGLU(_ActivationOperation):
r"""Rectified gated linear unit
r"""Rectified Gated Linear Unit
The input tensor is split into chunks :math:`a` and :math:`b`
along the last dimension and the following is computed:
......@@ -227,6 +295,67 @@ class ReGLU(_ActivationOperation):
return tex.dreglu(*args, **kwargs)
class SReLU(_ActivationOperation):
r"""Squared Rectified Linear Unit
.. math::
\text{SReLU}(x) = \max(x, 0)^2
See `Primer: Searching for Efficient Transformers for Language Modeling <https://arxiv.org/abs/2109.08668v2>`__.
"""
def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor:
return tex.srelu(*args, **kwargs)
def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
return tex.dsrelu(*args, **kwargs)
class SReGLU(_ActivationOperation):
r"""Squared Rectified Gated Linear Unit
The input tensor is split into chunks :math:`a` and :math:`b`
along the last dimension and the following is computed:
.. math::
\text{SReGLU}(a,b) = \max(a, 0)^2 * b
.. warning::
Transformer Engine's gated activations and PyTorch's GLU
activation follow opposite conventions for :math:`a` and
:math:`b`. Transformer Engine applies the gating function to
the first half of the input tensor, while PyTorch applies it to
the second half.
"""
def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor:
return tex.sreglu(*args, **kwargs)
def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
return tex.dsreglu(*args, **kwargs)
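A pure-PyTorch reference for the two squared-ReLU variants documented above; an illustrative sketch only, since tex.srelu/tex.sreglu and their backward kernels remain the actual implementation:
import torch

def srelu_ref(x: torch.Tensor) -> torch.Tensor:
    # max(x, 0)^2
    return torch.nn.functional.relu(x) ** 2

def sreglu_ref(x: torch.Tensor) -> torch.Tensor:
    # Gating function applied to the first half of the last dimension.
    a, b = x.chunk(2, dim=-1)
    return srelu_ref(a) * b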
class SiLU(_ActivationOperation):
r"""Sigmoid Linear Unit
.. math::
\text{SiLU}(x) = x \sigma(x) = \frac{x}{1+\exp(-x)}
"""
def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor:
return tex.silu(*args, **kwargs)
def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
return tex.dsilu(*args, **kwargs)
class SwiGLU(_ActivationOperation):
r"""Swish gated linear unit
......
......@@ -12,26 +12,32 @@ from typing import Any, Optional
import torch
from transformer_engine.pytorch.module.base import get_workspace
from ...cpp_extensions import general_gemm
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...distributed import (
CudaRNGStatesTracker,
gather_along_first_dim,
reduce_scatter_along_first_dim,
)
from ...fp8 import FP8GlobalStateManager, Recipe
from ...module.base import _2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD
from ...module.base import (
_2X_ACC_FPROP,
_2X_ACC_DGRAD,
_2X_ACC_WGRAD,
get_dummy_wgrad,
get_workspace,
)
from ...tensor import Quantizer
from ...tensor.float8_tensor import Float8Quantizer
from ...tensor._internal.float8_tensor_base import Float8TensorBase
from ..op import BasicOperation, OperationContext
from .._common import maybe_dequantize, is_quantized_tensor
from ...utils import (
canonicalize_device,
canonicalize_dtype,
clear_tensor_data,
devices_match,
)
from ..op import BasicOperation, OperationContext
from .._common import maybe_dequantize, is_quantized_tensor
def _wait_async(handle: Optional[Any]) -> None:
......@@ -73,7 +79,8 @@ class BasicLinear(BasicOperation):
weight's `main_grad` attribute instead of relying on PyTorch
autograd. The weight's `main_grad` must be set externally and
there is no guarantee that `grad` will be set or be
meaningful.
meaningful. This is primarily intended to integrate with
Megatron-LM.
userbuffers_options, dict, optional
Options for overlapping tensor-parallel communication with
compute using Userbuffers. This feature is highly
......@@ -958,6 +965,8 @@ class BasicLinear(BasicOperation):
# Save state for backward pass
if ctx.requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(x_local)
ctx.save_for_backward(x_local, w)
ctx.with_quantized_compute = with_quantized_compute
ctx.input_quantizer = input_quantizer
......@@ -979,20 +988,22 @@ class BasicLinear(BasicOperation):
# Saved tensors from forward pass
(x_local, w) = ctx.saved_tensors
# wgrad fusion
# Megatron-LM wgrad fusion
# Note: Get grad tensor from param so we can accumulate
# directly into it.
accumulate_into_main_grad = self._accumulate_into_main_grad
grad_weight = None
if ctx.weight_requires_grad and accumulate_into_main_grad:
if hasattr(self.weight, "__fsdp_param__"):
self.weight.main_grad = self.weight.get_main_grad()
if not hasattr(self.weight, "main_grad"):
weight_param = self.weight
if hasattr(weight_param, "__fsdp_param__"):
weight_param.main_grad = weight_param.get_main_grad()
if not hasattr(weight_param, "main_grad"):
raise RuntimeError(
"BasicLinear op is configured with "
"accumulate_into_main_grad=True, "
"but weight parameter does not have main_grad attribute"
)
grad_weight = self.weight.main_grad.detach()
grad_weight = weight_param.main_grad.detach()
else:
accumulate_into_main_grad = False
......@@ -1019,6 +1030,17 @@ class BasicLinear(BasicOperation):
# Clear input tensor if possible
clear_tensor_data(x_local)
# Megatron-LM wgrad fusion
# Note: Return dummy tensor for grad weight if needed.
if accumulate_into_main_grad:
grad_weight = None
weight_param = self.weight
if hasattr(weight_param, "grad_added_to_main_grad"):
weight_param.grad_added_to_main_grad = True
grad_weight = get_dummy_wgrad(
list(weight_param.size()),
weight_param.dtype,
zero=getattr(weight_param, "zero_out_wgrad", False),
)
return grad_input, [grad_weight]
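For context on the block above, a sketch of the contract it satisfies (get_dummy_wgrad's exact behavior is TE-internal and this helper is hypothetical): when the real weight gradient has already been accumulated into main_grad, autograd still expects a tensor of the weight's shape, so a cheap placeholder is returned, optionally zeroed:
import torch

def dummy_wgrad_like(weight: torch.Tensor, zero: bool = False) -> torch.Tensor:
    # Placeholder handed back to autograd; the real gradient already
    # lives in weight.main_grad.
    buf = torch.empty_like(weight)
    return buf.zero_() if zero else buf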
......@@ -8,12 +8,12 @@ from __future__ import annotations
from typing import Optional
import torch
from transformer_engine.pytorch.ops.op import (
BasicOperation,
OperationContext,
)
import transformer_engine_torch as tex
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...tensor import Quantizer
from ...tensor._internal.float8_tensor_base import Float8TensorBase
from .._common import maybe_autocast_dtype, maybe_dequantize
from ..op import BasicOperation, OperationContext
class Dropout(BasicOperation):
......@@ -27,7 +27,7 @@ class Dropout(BasicOperation):
def __init__(self, p: float) -> None:
super().__init__()
self.dropout_probability = p
self.dropout_probability: float = p
def op_forward(
self,
......@@ -37,21 +37,46 @@ class Dropout(BasicOperation):
next_op_input_quantizer: Optional[Quantizer],
) -> torch.Tensor:
# Compute dropout if training
out = input_
is_training = self.training
mask = None
if is_training:
# Output dtype
dtype = maybe_autocast_dtype(default_dtype=input_.dtype)
# Choose implementation
impl = None
if not self.training:
impl = "evaluation"
elif input_.numel() % 16 == 0 and dtype in (torch.float16, torch.bfloat16):
impl = "fused"
else:
impl = "unfused"
# Perform dropout
out: torch.Tensor
mask: Optional[torch.Tensor] = None
if impl == "evaluation":
out = input_
elif impl == "fused":
x = input_
if not isinstance(x, Float8TensorBase):
x = maybe_dequantize(x, dtype=dtype)
out, mask = tex.dropout_fwd(x, self.dropout_probability)
elif impl == "unfused":
x = maybe_dequantize(input_, dtype=dtype)
keep_prob = 1 - self.dropout_probability
mask = torch.empty_like(input_)
mask = torch.empty_like(x)
mask.bernoulli_(keep_prob)
mask *= 1 / keep_prob
out = out * mask
out = x * mask
else:
raise ValueError(f"Unsupported forward implementation {impl}")
# Save context for backward
if ctx.requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(mask)
ctx.save_for_backward(mask)
ctx.is_training = is_training
ctx.impl = impl
ctx.dropout_probability = self.dropout_probability
ctx.dtype = dtype
return out
......@@ -60,8 +85,21 @@ class Dropout(BasicOperation):
ctx: OperationContext,
grad_output: torch.Tensor,
) -> tuple[torch.Tensor, tuple[()]]:
# Saved tensors from forward pass
(mask,) = ctx.saved_tensors
grad_input = grad_output
if ctx.is_training:
grad_input = grad_input * mask
# Perform dropout backward pass
grad_input: torch.Tensor
if ctx.impl == "evaluation":
grad_input = grad_output
elif ctx.impl == "fused":
dy = maybe_dequantize(grad_output, dtype=ctx.dtype)
grad_input = tex.dropout_bwd(dy, mask, ctx.dropout_probability)
elif ctx.impl == "unfused":
dy = maybe_dequantize(grad_output, dtype=ctx.dtype)
grad_input = dy * mask
else:
raise ValueError(f"Unsupported backward implementation {ctx.impl}")
return grad_input, ()
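The forward pass above picks between three implementations; the rule is restated below as a standalone sketch (illustrative, mirroring the conditions in op_forward, where dtype is the possibly-autocast compute dtype):
import torch

def choose_dropout_impl(training: bool, x: torch.Tensor) -> str:
    if not training:
        return "evaluation"  # identity, no mask saved
    if x.numel() % 16 == 0 and x.dtype in (torch.float16, torch.bfloat16):
        return "fused"       # eligible for the tex.dropout_fwd kernel
    return "unfused"         # Bernoulli mask in plain PyTorch

assert choose_dropout_impl(True, torch.zeros(2, 8, dtype=torch.bfloat16)) == "fused"
assert choose_dropout_impl(False, torch.zeros(3)) == "evaluation"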
......@@ -10,10 +10,8 @@ import os
import torch
from ...utils import clear_tensor_data
from ... import torch_version
from .._common import maybe_dequantize
from ..op import BasicOperation, OperationContext
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...jit import (
l2normalization_fused,
l2normalization_fwd_fused,
......@@ -22,6 +20,9 @@ from ...jit import (
warmup_jit_l2normalization_all_dtypes,
)
from ...tensor import Quantizer
from ...utils import clear_tensor_data
from ..op import BasicOperation, OperationContext
from .._common import maybe_dequantize
class L2Normalization(BasicOperation):
......@@ -101,6 +102,8 @@ class L2Normalization(BasicOperation):
# Save state for backward pass
if requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(x, rsqrt_norm)
ctx.save_for_backward(x, rsqrt_norm)
return y
......
......@@ -14,6 +14,9 @@ import torch
from transformer_engine_torch import layernorm_bwd, layernorm_fwd
from ...constants import TE_DType
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...export import is_in_onnx_export_mode
from ...tensor import Quantizer
from ...utils import (
canonicalize_device,
canonicalize_dtype,
......@@ -22,8 +25,6 @@ from ...utils import (
)
from ..op import BasicOperation, OperationContext
from .._common import maybe_autocast_dtype, maybe_dequantize
from ...export import is_in_onnx_export_mode
from ...tensor import Quantizer
class LayerNorm(BasicOperation):
......@@ -215,6 +216,8 @@ class LayerNorm(BasicOperation):
# Save state for backward pass
if ctx.requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(x, means, rstdevs)
ctx.save_for_backward(x, means, rstdevs)
ctx.dtype = dtype
......
......@@ -14,6 +14,9 @@ import torch
from transformer_engine_torch import rmsnorm_bwd, rmsnorm_fwd
from ...constants import TE_DType
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...export import is_in_onnx_export_mode
from ...tensor import Quantizer
from ...utils import (
canonicalize_device,
canonicalize_dtype,
......@@ -22,8 +25,6 @@ from ...utils import (
)
from ..op import BasicOperation, OperationContext
from .._common import maybe_autocast_dtype, maybe_dequantize
from ...export import is_in_onnx_export_mode
from ...tensor import Quantizer
class RMSNorm(BasicOperation):
......@@ -196,6 +197,8 @@ class RMSNorm(BasicOperation):
# Save state for backward pass
if ctx.requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(x, rstdevs)
ctx.save_for_backward(x, rstdevs)
ctx.dtype = dtype
......
......@@ -8,6 +8,10 @@ from .backward_activation_bias import (
BackwardActivationBias,
fuse_backward_activation_bias,
)
from .backward_add_rmsnorm import (
BackwardAddRMSNorm,
fuse_backward_add_rmsnorm,
)
from .backward_linear_add import (
BackwardLinearAdd,
fuse_backward_linear_add,
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Fused backward RMNorm + add."""
from __future__ import annotations
from typing import Optional
import math
import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.ops.basic import MakeExtraOutput, RMSNorm
from transformer_engine.pytorch.ops.op import (
FusedOperation,
FusibleOperation,
OperationContext,
)
from ...utils import clear_tensor_data
from .._common import maybe_dequantize
class BackwardAddRMSNorm(FusedOperation):
"""Fused backward RMNorm + add"""
def __init__(self, *, add: MakeExtraOutput, rmsnorm: RMSNorm):
super().__init__((add, rmsnorm))
def fuser_backward(
self,
basic_op_ctxs: list[OperationContext],
grad_output: torch.Tensor,
*,
basic_op_grad_extra_outputs: list[tuple[torch.Tensor, ...]],
) -> tuple[
torch.Tensor,
list[tuple[Optional[torch.Tensor], ...]],
list[tuple[()]],
]:
# Get basic operations
rmsnorm_op = self.basic_ops[1]
rmsnorm_op_ctx = basic_op_ctxs[0]
# Saved tensors from forward pass
x, rstdevs = rmsnorm_op_ctx.saved_tensors
# Tensor dims
weight_dims = rmsnorm_op.weight.size()
inner_dim = math.prod(weight_dims)
# Check input tensors
dtype = rmsnorm_op_ctx.dtype
extra_grad = basic_op_grad_extra_outputs[1][0]
dy = maybe_dequantize(grad_output.contiguous(), dtype).view(x.size())
w = maybe_dequantize(rmsnorm_op.weight, dtype).view((inner_dim,))
add = maybe_dequantize(extra_grad.contiguous(), dtype).view(x.size())
# Compute RMSNorm backward pass
dx, dw = tex.rmsnorm_bwd_add(
dy,
x,
add,
rstdevs,
w,
rmsnorm_op._sm_margins["backward"],
rmsnorm_op.zero_centered_gamma,
)
# Clear saved tensors if possible
clear_tensor_data(x)
clear_tensor_data(rstdevs)
# Reshape results
grad_input = dx.view(grad_output.size())
grad_weight = dw.view(weight_dims)
return grad_input, [(grad_weight,), ()], [(), ()]
def fuse_backward_add_rmsnorm(
ops: list[tuple[FusibleOperation, list[int]]],
) -> list[tuple[FusibleOperation, list[int]]]:
"""Fused backward RMNorm + add
Parameters
----------
ops: list of tuples
Backward pass operations and the indices of the corresponding
basic operations.
Returns
-------
ops: list of tuples
Updated backward pass operations
"""
# Scan through ops, fusing if possible
out = []
window = []
while len(ops) >= 2:
out.extend(window)
# Check if first op is RMSNorm
window, ops = ops[:1], ops[1:]
op, _ = window[0]
if not isinstance(op, RMSNorm):
continue
# Check if second op is "make extra output"
op, _ = ops[0]
if not isinstance(op, MakeExtraOutput):
continue
if op._in_place:
continue
window.extend(ops[:1])
ops = ops[1:]
# Replace window with fused op
op = BackwardAddRMSNorm(
rmsnorm=window[0][0],
add=window[1][0],
)
basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window]
window = [(op, basic_op_idxs)]
# Return list of ops
out.extend(window)
out.extend(ops)
return out
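Schematically, with a hypothetical op list for illustration, the scan above replaces an adjacent [RMSNorm, MakeExtraOutput] pair with one fused op that keeps both basic-op indices:
# ops_in  = [(rmsnorm, [3]), (make_extra_output, [4]), (linear, [5])]
# ops_out = [(BackwardAddRMSNorm(add=..., rmsnorm=...), [3, 4]), (linear, [5])]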
......@@ -9,13 +9,10 @@ from typing import Optional
import torch
from transformer_engine.pytorch.ops.basic import BasicLinear, MakeExtraOutput
from transformer_engine.pytorch.ops.op import (
FusedOperation,
FusibleOperation,
OperationContext,
)
from ...module.base import get_dummy_wgrad
from ...utils import clear_tensor_data
from ..basic import BasicLinear, MakeExtraOutput
from ..op import FusedOperation, FusibleOperation, OperationContext
class BackwardLinearAdd(FusedOperation):
......@@ -53,20 +50,22 @@ class BackwardLinearAdd(FusedOperation):
# Saved tensors from forward pass
(x_local, w) = linear_op_ctx.saved_tensors
# wgrad fusion
# Megatron-LM wgrad fusion
# Note: Get grad tensor from param so we can accumulate
# directly into it.
accumulate_into_main_grad = linear_op._accumulate_into_main_grad
grad_weight = None
if linear_op_ctx.weight_requires_grad and accumulate_into_main_grad:
if hasattr(linear_op.weight, "__fsdp_param__"):
linear_op.weight.main_grad = linear_op.weight.get_main_grad()
if not hasattr(linear_op.weight, "main_grad"):
weight_param = linear_op.weight
if hasattr(weight_param, "__fsdp_param__"):
weight_param.main_grad = weight_param.get_main_grad()
if not hasattr(weight_param, "main_grad"):
raise RuntimeError(
"BasicLinear op is configured with "
"accumulate_into_main_grad=True, "
"but weight parameter does not have main_grad attribute"
)
grad_weight = linear_op.weight.main_grad.detach()
grad_weight = weight_param.main_grad.detach()
else:
accumulate_into_main_grad = False
......@@ -92,12 +91,23 @@ class BackwardLinearAdd(FusedOperation):
grad_output_quantizer=linear_op_ctx.grad_output_quantizer,
grad_input_quantizer=linear_op_ctx.grad_input_quantizer,
)
if accumulate_into_main_grad:
grad_weight = None
# Clear input tensor if possible
clear_tensor_data(x_local)
# Megatron-LM wgrad fusion
# Note: Return dummy tensor for grad weight if needed.
if accumulate_into_main_grad:
grad_weight = None
weight_param = linear_op.weight
if hasattr(weight_param, "grad_added_to_main_grad"):
weight_param.grad_added_to_main_grad = True
grad_weight = get_dummy_wgrad(
list(weight_param.size()),
weight_param.dtype,
zero=getattr(weight_param, "zero_out_wgrad", False),
)
return grad_input, [(grad_weight,), ()], [(), ()]
......
......@@ -9,13 +9,10 @@ from typing import Optional
import torch
from ..basic import BasicLinear, ConstantScale
from ..op import (
FusedOperation,
FusibleOperation,
OperationContext,
)
from ...module.base import get_dummy_wgrad
from ...utils import clear_tensor_data
from ..basic import BasicLinear, ConstantScale
from ..op import FusedOperation, FusibleOperation, OperationContext
class BackwardLinearScale(FusedOperation):
......@@ -54,20 +51,22 @@ class BackwardLinearScale(FusedOperation):
# Saved tensors from forward pass
(x_local, w) = linear_op_ctx.saved_tensors
# wgrad fusion
# Megatron-LM wgrad fusion
# Note: Get grad tensor from param so we can accumulate
# directly into it.
accumulate_into_main_grad = linear_op._accumulate_into_main_grad
grad_weight = None
if linear_op_ctx.weight_requires_grad and accumulate_into_main_grad:
if hasattr(linear_op.weight, "__fsdp_param__"):
linear_op.weight.main_grad = linear_op.weight.get_main_grad()
if not hasattr(linear_op.weight, "main_grad"):
weight_param = linear_op.weight
if hasattr(weight_param, "__fsdp_param__"):
weight_param.main_grad = weight_param.get_main_grad()
if not hasattr(weight_param, "main_grad"):
raise RuntimeError(
"BasicLinear op is configured with "
"accumulate_into_main_grad=True, "
"but weight parameter does not have main_grad attribute"
)
grad_weight = linear_op.weight.main_grad.detach()
grad_weight = weight_param.main_grad.detach()
else:
accumulate_into_main_grad = False
......@@ -92,12 +91,23 @@ class BackwardLinearScale(FusedOperation):
grad_output_quantizer=linear_op_ctx.grad_output_quantizer,
grad_input_quantizer=linear_op_ctx.grad_input_quantizer,
)
if accumulate_into_main_grad:
grad_weight = None
# Clear input tensor if possible
clear_tensor_data(x_local)
# Megatron-LM wgrad fusion
# Note: Return dummy tensor for grad weight if needed.
if accumulate_into_main_grad:
grad_weight = None
weight_param = linear_op.weight
if hasattr(weight_param, "grad_added_to_main_grad"):
weight_param.grad_added_to_main_grad = True
grad_weight = get_dummy_wgrad(
list(weight_param.size()),
weight_param.dtype,
zero=getattr(weight_param, "zero_out_wgrad", False),
)
return grad_input, [(), (grad_weight,)], [(), ()]
......
......@@ -10,14 +10,11 @@ from typing import Any, Optional
import torch
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.ops.basic import BasicLinear, Bias
from transformer_engine.pytorch.ops.op import (
FusedOperation,
FusibleOperation,
OperationContext,
)
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...fp8 import FP8GlobalStateManager
from ...tensor import Quantizer
from ..basic import BasicLinear, Bias
from ..op import FusedOperation, FusibleOperation, OperationContext
class ForwardLinearBiasActivation(FusedOperation):
......@@ -121,6 +118,8 @@ class ForwardLinearBiasActivation(FusedOperation):
# Save state for backward pass
if linear_op_ctx.requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(x_local)
linear_op_ctx.save_for_backward(x_local, w)
linear_op_ctx.with_quantized_compute = with_quantized_compute
linear_op_ctx.input_quantizer = input_quantizer
......
......@@ -10,14 +10,11 @@ from typing import Any, Optional
import torch
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.ops.basic import AddExtraInput, BasicLinear, Bias
from transformer_engine.pytorch.ops.op import (
FusedOperation,
FusibleOperation,
OperationContext,
)
from transformer_engine.pytorch.tensor import Quantizer
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...fp8 import FP8GlobalStateManager
from ...tensor import Quantizer
from ..basic import AddExtraInput, BasicLinear, Bias
from ..op import FusedOperation, FusibleOperation, OperationContext
class ForwardLinearBiasAdd(FusedOperation):
......@@ -118,6 +115,8 @@ class ForwardLinearBiasAdd(FusedOperation):
# Save state for backward pass
if linear_op_ctx.requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(x_local)
linear_op_ctx.save_for_backward(x_local, w)
linear_op_ctx.with_quantized_compute = with_quantized_compute
linear_op_ctx.input_quantizer = input_quantizer
......
......@@ -10,14 +10,15 @@ from typing import Any, Optional
import torch
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...fp8 import FP8GlobalStateManager
from ...tensor import Quantizer
from ..basic import AddExtraInput, BasicLinear, ConstantScale
from ..op import (
FusedOperation,
FusibleOperation,
OperationContext,
)
from ...tensor import Quantizer
class ForwardLinearScaleAdd(FusedOperation):
......@@ -95,6 +96,8 @@ class ForwardLinearScaleAdd(FusedOperation):
# Save state for backward pass
if linear_op_ctx.requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(x_local)
linear_op_ctx.save_for_backward(x_local, w)
linear_op_ctx.with_quantized_compute = with_quantized_compute
linear_op_ctx.input_quantizer = input_quantizer
......
......@@ -14,11 +14,12 @@ from transformer_engine_torch import CommOverlapType, bulk_overlap_ag_with_exter
from ...cpp_extensions import general_gemm
from ...distributed import get_distributed_world_size
from ...module.base import (
_2X_ACC_DGRAD,
_2X_ACC_WGRAD,
fill_userbuffers_buffer_for_all_gather,
get_dummy_wgrad,
get_ub,
get_workspace,
_2X_ACC_DGRAD,
_2X_ACC_WGRAD,
)
from ...tensor.quantized_tensor import Quantizer
from ...tensor.mxfp8_tensor import MXFP8Quantizer
......@@ -240,16 +241,16 @@ class UserbuffersBackwardLinear(FusedOperation):
with_dgrad_all_gather_x = False
with_wgrad_reduce_scatter_dx = False
if tensor_parallel_mode == "row":
ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad")
ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad", with_quantized_compute)
ub_type_dgrad = CommOverlapType.AG
with_dgrad_all_gather_dy = True
elif tensor_parallel_mode == "column":
if input_requires_grad and weight_requires_grad:
with_bulk_overlap = True
ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad")
ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad", with_quantized_compute)
ub_type_dgrad = CommOverlapType.AG
with_dgrad_all_gather_x = True
ub_comm_wgrad = get_ub(ub_comm_name + "_wgrad")
ub_comm_wgrad = get_ub(ub_comm_name + "_wgrad", with_quantized_compute)
ub_type_wgrad = CommOverlapType.RS
with_wgrad_reduce_scatter_dx = True
if ub_comm_wgrad.is_fp8_ubuf():
......@@ -257,7 +258,7 @@ class UserbuffersBackwardLinear(FusedOperation):
"Userbuffers reduce-scatter is not supported with FP8 buffers"
)
else:
ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad")
ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad", with_quantized_compute)
ub_type_dgrad = CommOverlapType.RS
with_dgrad_reduce_scatter_dx = True
if ub_comm_dgrad.is_fp8_ubuf():
......@@ -408,7 +409,7 @@ class UserbuffersBackwardLinear(FusedOperation):
# Get the communication stream from the dgrad GEMM to use for the AG
dgrad_send_stream, dgrad_recv_stream = ub_comm_dgrad.get_communication_stream()
ub_obj_overlap_wgrad = get_ub(ub_comm_name + "_wgrad")
ub_obj_overlap_wgrad = get_ub(ub_comm_name + "_wgrad", with_quantized_compute)
grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
......@@ -513,20 +514,22 @@ class UserbuffersBackwardLinear(FusedOperation):
# Saved tensors from forward pass
(x_local, w) = linear_op_ctx.saved_tensors
# wgrad fusion
# Megatron-LM wgrad fusion
# Note: Get grad tensor from param so we can accumulate
# directly into it.
accumulate_into_main_grad = linear_op._accumulate_into_main_grad
grad_weight = None
if linear_op_ctx.weight_requires_grad and accumulate_into_main_grad:
if hasattr(linear_op.weight, "__fsdp_param__"):
linear_op.weight.main_grad = linear_op.weight.get_main_grad()
if not hasattr(linear_op.weight, "main_grad"):
weight_param = linear_op.weight
if hasattr(weight_param, "__fsdp_param__"):
weight_param.main_grad = weight_param.get_main_grad()
if not hasattr(weight_param, "main_grad"):
raise RuntimeError(
"BasicLinear op is configured with "
"accumulate_into_main_grad=True, "
"but weight parameter does not have main_grad attribute"
)
grad_weight = linear_op.weight.main_grad.detach()
grad_weight = weight_param.main_grad.detach()
else:
accumulate_into_main_grad = False
......@@ -558,10 +561,21 @@ class UserbuffersBackwardLinear(FusedOperation):
# Clear input tensor if possible
clear_tensor_data(x_local)
# Return gradients
grad_params = [() for _ in range(len(self.basic_ops))]
# Megatron-LM wgrad fusion
# Note: Return dummy tensor for grad weight if needed.
if accumulate_into_main_grad:
grad_weight = None
weight_param = linear_op.weight
if hasattr(weight_param, "grad_added_to_main_grad"):
weight_param.grad_added_to_main_grad = True
grad_weight = get_dummy_wgrad(
list(weight_param.size()),
weight_param.dtype,
zero=getattr(weight_param, "zero_out_wgrad", False),
)
# Return gradients
grad_params = [() for _ in range(len(self.basic_ops))]
grad_params[self._op_idxs["linear"]] = (grad_weight,)
if bias_op is not None:
grad_params[self._op_idxs["bias"]] = (grad_bias,)
......
......@@ -12,6 +12,7 @@ import torch
from transformer_engine_torch import CommOverlapType
from ...cpp_extensions import general_gemm
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...distributed import get_distributed_world_size
from ...fp8 import FP8GlobalStateManager
from ...module.base import (
......@@ -189,7 +190,7 @@ class UserbuffersForwardLinear(FusedOperation):
output_quantizer = None
# Get Userbuffers communicator
ub_comm = get_ub(ub_comm_name + "_fprop")
ub_comm = get_ub(ub_comm_name + "_fprop", with_quantized_compute)
with_ub_all_gather = tensor_parallel_mode == "column"
with_ub_reduce_scatter = tensor_parallel_mode == "row"
ub_type = CommOverlapType.AG if with_ub_all_gather else CommOverlapType.RS
......@@ -353,6 +354,8 @@ class UserbuffersForwardLinear(FusedOperation):
# Save state for backward pass
if linear_op_ctx.requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(x_local)
linear_op_ctx.save_for_backward(x_local, w)
linear_op_ctx.with_quantized_compute = with_quantized_compute
linear_op_ctx.input_quantizer = input_quantizer
......