Commit 27ddce40 authored by wenjh's avatar wenjh
Browse files

Merge branch 'nv_main'

parents d262ef4c 5b3092a0
...@@ -94,39 +94,45 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None): ...@@ -94,39 +94,45 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None):
# bf16 (recipe is None): # bf16 (recipe is None):
return { return {
"gelu": (tex.gelu, tex.dgelu, None), "gelu": (tex.gelu, tex.dgelu, None),
"relu": (tex.relu, tex.drelu, None),
"geglu": (tex.geglu, tex.dgeglu, None), "geglu": (tex.geglu, tex.dgeglu, None),
"reglu": (tex.reglu, tex.dreglu, None),
"swiglu": (tex.swiglu, tex.dswiglu, None),
"qgelu": (tex.qgelu, tex.dqgelu, None), "qgelu": (tex.qgelu, tex.dqgelu, None),
"qgeglu": (tex.qgeglu, tex.dqgeglu, None), "qgeglu": (tex.qgeglu, tex.dqgeglu, None),
"relu": (tex.relu, tex.drelu, None),
"reglu": (tex.reglu, tex.dreglu, None),
"srelu": (tex.srelu, tex.dsrelu, None), "srelu": (tex.srelu, tex.dsrelu, None),
"sreglu": (tex.sreglu, tex.dsreglu, None),
"silu": (tex.silu, tex.dsilu, None),
"swiglu": (tex.swiglu, tex.dswiglu, None),
} }
if recipe.delayed() or recipe.mxfp8(): if recipe.delayed() or recipe.mxfp8():
# Delayed scaling, fusion supported list: [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu] # Delayed scaling, fusion supported list: [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu]
# MXFP8: [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu] # MXFP8: [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu]
return { return {
"gelu": (tex.gelu, tex.dgelu, tex.dbias_dgelu), "gelu": (tex.gelu, tex.dgelu, tex.dbias_dgelu),
"relu": (tex.relu, tex.drelu, tex.dbias_drelu),
"geglu": (tex.geglu, tex.dgeglu, None), "geglu": (tex.geglu, tex.dgeglu, None),
"reglu": (tex.reglu, tex.dreglu, None),
"swiglu": (tex.swiglu, tex.dswiglu, None),
"qgelu": (tex.qgelu, tex.dqgelu, tex.dbias_dqgelu), "qgelu": (tex.qgelu, tex.dqgelu, tex.dbias_dqgelu),
"qgeglu": (tex.qgeglu, tex.dqgeglu, None), "qgeglu": (tex.qgeglu, tex.dqgeglu, None),
"relu": (tex.relu, tex.drelu, tex.dbias_drelu),
"reglu": (tex.reglu, tex.dreglu, None),
"srelu": (tex.srelu, tex.dsrelu, tex.dbias_dsrelu), "srelu": (tex.srelu, tex.dsrelu, tex.dbias_dsrelu),
"sreglu": (tex.sreglu, tex.dsreglu, None),
"silu": (tex.silu, tex.dsilu, tex.dbias_dsilu),
"swiglu": (tex.swiglu, tex.dswiglu, None),
} }
# no activation fusion written yet # no activation fusion written yet
# Per-tensor current scaling or fp8 blockwise scaling: [] # Per-tensor current scaling or fp8 blockwise scaling: []
if recipe.float8_current_scaling() or recipe.float8_block_scaling(): if recipe.float8_current_scaling() or recipe.float8_block_scaling():
return { return {
"gelu": (tex.gelu, tex.dgelu, None), "gelu": (tex.gelu, tex.dgelu, None),
"relu": (tex.relu, tex.drelu, None),
"geglu": (tex.geglu, tex.dgeglu, None), "geglu": (tex.geglu, tex.dgeglu, None),
"reglu": (tex.reglu, tex.dreglu, None),
"swiglu": (tex.swiglu, tex.dswiglu, None),
"qgelu": (tex.qgelu, tex.dqgelu, None), "qgelu": (tex.qgelu, tex.dqgelu, None),
"qgeglu": (tex.qgeglu, tex.dqgeglu, None), "qgeglu": (tex.qgeglu, tex.dqgeglu, None),
"relu": (tex.relu, tex.drelu, None),
"reglu": (tex.reglu, tex.dreglu, None),
"srelu": (tex.srelu, tex.dsrelu, None), "srelu": (tex.srelu, tex.dsrelu, None),
"sreglu": (tex.sreglu, tex.dsreglu, None),
"silu": (tex.silu, tex.dsilu, None),
"swiglu": (tex.swiglu, tex.dswiglu, None),
} }
raise NotImplementedError(f"Unhandled recipe type {recipe}") raise NotImplementedError(f"Unhandled recipe type {recipe}")
...@@ -308,7 +314,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -308,7 +314,7 @@ class _LayerNormMLP(torch.autograd.Function):
fc1_input_quantizer.set_usage(rowwise=True, columnwise=False) fc1_input_quantizer.set_usage(rowwise=True, columnwise=False)
if ub_overlap_ag: if ub_overlap_ag:
# Copy into Userbuffers buffer # Copy into Userbuffers buffer
ub_obj_lnout = get_ub("fc1_fprop") ub_obj_lnout = get_ub("fc1_fprop", fp8)
ln_out_total, _ = fill_userbuffers_buffer_for_all_gather( ln_out_total, _ = fill_userbuffers_buffer_for_all_gather(
ub_obj_lnout, ub_obj_lnout,
ln_out, ln_out,
...@@ -446,20 +452,25 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -446,20 +452,25 @@ class _LayerNormMLP(torch.autograd.Function):
act_out = activation_func(fc1_out, None) act_out = activation_func(fc1_out, None)
act_out = tex.quantize(act_out, fc2_input_quantizer) act_out = tex.quantize(act_out, fc2_input_quantizer)
else: else:
act_out = activation_func(fc1_out, fc2_input_quantizer) if fp8_calibration:
act_out = activation_func(fc1_out, None)
else:
act_out = activation_func(fc1_out, fc2_input_quantizer)
if not is_grad_enabled: if not is_grad_enabled:
clear_tensor_data(fc1_out) clear_tensor_data(fc1_out)
if fp8_calibration: if not fp8 and fp8_calibration:
fc2_input_quantizer.calibrate(act_out) if fc2_input_quantizer is not None:
fc2_weight_quantizer.calibrate(fc2_weight) fc2_input_quantizer.calibrate(act_out)
if fc2_weight_quantizer is not None:
fc2_weight_quantizer.calibrate(fc2_weight)
# Configure Userbuffers reduce-scatter if needed # Configure Userbuffers reduce-scatter if needed
ub_obj_fc2out = None ub_obj_fc2out = None
reduce_scatter_out = None reduce_scatter_out = None
if ub_overlap_rs: if ub_overlap_rs:
ub_obj_fc2out = get_ub("fc2_fprop") ub_obj_fc2out = get_ub("fc2_fprop", fp8)
dim_size = list(act_out.size()) dim_size = list(act_out.size())
dim_size[0] //= tp_world_size dim_size[0] //= tp_world_size
dim_size[-1] = fc2_weight.size(0) dim_size[-1] = fc2_weight.size(0)
...@@ -741,7 +752,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -741,7 +752,7 @@ class _LayerNormMLP(torch.autograd.Function):
# Note: Cast to expected dtype and perform tensor-parallel communication # Note: Cast to expected dtype and perform tensor-parallel communication
ub_obj_fc2_dgrad = None ub_obj_fc2_dgrad = None
if ctx.ub_overlap_ag: if ctx.ub_overlap_ag:
ub_obj_fc2_dgrad = get_ub("fc2_dgrad") ub_obj_fc2_dgrad = get_ub("fc2_dgrad", ctx.fp8)
ctx.ub_obj_gradout = ub_obj_fc2_dgrad ctx.ub_obj_gradout = ub_obj_fc2_dgrad
( (
grad_output, grad_output,
...@@ -765,7 +776,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -765,7 +776,7 @@ class _LayerNormMLP(torch.autograd.Function):
# wgrad GEMM requires input with column-wise usage # wgrad GEMM requires input with column-wise usage
quantizer.set_usage(rowwise=False, columnwise=True) quantizer.set_usage(rowwise=False, columnwise=True)
if ctx.ub_bulk_dgrad: if ctx.ub_bulk_dgrad:
ub_obj_fc1_dgrad = get_ub("fc1_dgrad") ub_obj_fc1_dgrad = get_ub("fc1_dgrad", ctx.fp8)
ln_out_total, _ = fill_userbuffers_buffer_for_all_gather( ln_out_total, _ = fill_userbuffers_buffer_for_all_gather(
ub_obj_fc1_dgrad, ub_obj_fc1_dgrad,
ln_out, ln_out,
...@@ -870,7 +881,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -870,7 +881,7 @@ class _LayerNormMLP(torch.autograd.Function):
ub_obj_fc2_dgrad.get_communication_stream() ub_obj_fc2_dgrad.get_communication_stream()
) )
ub_obj_fc2_wgrad = get_ub("fc2_wgrad") ub_obj_fc2_wgrad = get_ub("fc2_wgrad", ctx.fp8)
ctx.fc2_grad_output_quantizer.set_usage(rowwise=False, columnwise=True) ctx.fc2_grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
...@@ -1045,16 +1056,16 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -1045,16 +1056,16 @@ class _LayerNormMLP(torch.autograd.Function):
fc1_dgrad_shape = [reduce(multiply_op, inputmat.shape[:-1]), inputmat.shape[-1]] fc1_dgrad_shape = [reduce(multiply_op, inputmat.shape[:-1]), inputmat.shape[-1]]
if ctx.ub_overlap_rs_dgrad: if ctx.ub_overlap_rs_dgrad:
# Overlap DGRAD+RS # Overlap DGRAD+RS
ub_obj_fc1_dgrad = get_ub("fc1_dgrad") ub_obj_fc1_dgrad = get_ub("fc1_dgrad", ctx.fp8)
ub_type_fc1_dgrad = tex.CommOverlapType.RS ub_type_fc1_dgrad = tex.CommOverlapType.RS
else: else:
if ctx.ub_bulk_dgrad: if ctx.ub_bulk_dgrad:
# Overlap ln_out all-gather with DGRAD compute # Overlap ln_out all-gather with DGRAD compute
ub_obj_fc1_dgrad = get_ub("fc1_dgrad") ub_obj_fc1_dgrad = get_ub("fc1_dgrad", ctx.fp8)
ub_type_fc1_dgrad = tex.CommOverlapType.AG ub_type_fc1_dgrad = tex.CommOverlapType.AG
if ctx.ub_bulk_wgrad: if ctx.ub_bulk_wgrad:
# Overlap FC1 DGRAD reduce-scatter with WGRAD compute # Overlap FC1 DGRAD reduce-scatter with WGRAD compute
ub_obj_fc1_wgrad = get_ub("fc1_wgrad") ub_obj_fc1_wgrad = get_ub("fc1_wgrad", ctx.fp8)
ub_type_fc1_wgrad = tex.CommOverlapType.RS ub_type_fc1_wgrad = tex.CommOverlapType.RS
# -------------------------------------------------- # --------------------------------------------------
...@@ -1402,7 +1413,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -1402,7 +1413,7 @@ class _LayerNormMLP(torch.autograd.Function):
class LayerNormMLP(TransformerEngineBaseModule): class LayerNormMLP(TransformerEngineBaseModule):
r""" r"""
Applies layer normalization on the input followed by the MLP module, consisting of Applies layer normalization on the input followed by the MLP module, consisting of
2 successive linear transformations, separated by the GeLU activation. 2 successive linear transformations, separated by the activation function.
Parameters Parameters
---------- ----------
...@@ -1418,7 +1429,8 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1418,7 +1429,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
type of normalization applied. type of normalization applied.
activation : str, default = 'gelu' activation : str, default = 'gelu'
activation function used. activation function used.
Options: 'gelu', 'geglu', 'relu', 'reglu', 'squared_relu', 'swiglu', 'qgelu', 'srelu'. Options: 'gelu', 'geglu', 'qgelu', 'qgeglu', 'relu', 'reglu', 'srelu', 'sreglu',
'silu', and 'swiglu'.
init_method : Callable, default = `None` init_method : Callable, default = `None`
used for initializing FC1 weights in the following way: `init_method(weight)`. used for initializing FC1 weights in the following way: `init_method(weight)`.
When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`. When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
...@@ -1559,7 +1571,11 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1559,7 +1571,11 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.gemm_gelu_fusion = ( self.gemm_gelu_fusion = (
bool(int(os.getenv("NVTE_GEMM_GELU_FUSION", "0"))) bool(int(os.getenv("NVTE_GEMM_GELU_FUSION", "0")))
and self.activation == "gelu" and self.activation == "gelu"
and ((_ub_communicators is None) or (not get_ub("fc1_fprop").is_atomic_gemm())) and all(
("fc1_fprop", use_fp8) not in _ub_communicators
or not get_ub("fc1_fprop", use_fp8).is_atomic_gemm()
for use_fp8 in [False, True]
)
) )
self.name = name self.name = name
...@@ -1619,7 +1635,7 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1619,7 +1635,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.layer_norm_bias = None self.layer_norm_bias = None
# FC1 init # FC1 init
if self.activation in ["reglu", "geglu", "qgeglu", "swiglu"]: if self.activation in ["geglu", "qgeglu", "reglu", "sreglu", "swiglu"]:
fc1_output_features = 2 * self.size_per_partition fc1_output_features = 2 * self.size_per_partition
else: else:
fc1_output_features = self.size_per_partition fc1_output_features = self.size_per_partition
...@@ -1777,7 +1793,7 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1777,7 +1793,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
fp8_output = False fp8_output = False
if self.ub_overlap_rs: if self.ub_overlap_rs:
if get_ub("fc2_fprop").is_fp8_ubuf(): if get_ub("fc2_fprop", FP8GlobalStateManager.is_fp8_enabled()).is_fp8_ubuf():
fp8_output = True fp8_output = True
with torch.cuda.device( with torch.cuda.device(
...@@ -1915,7 +1931,7 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1915,7 +1931,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
fc2_grad_output_quantizer, fc2_grad_output_quantizer,
) = [None] * 10 ) = [None] * 10
fc1_weight_quantizer, fc2_weight_quantizer = self._get_weight_quantizers() fc1_weight_quantizer, fc2_weight_quantizer = self._get_weight_quantizers()
if self.fp8: if self.fp8 or self.fp8_calibration:
fc1_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] fc1_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
fc1_input_quantizer.internal = True fc1_input_quantizer.internal = True
fc2_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT] fc2_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT]
...@@ -2001,14 +2017,17 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -2001,14 +2017,17 @@ class LayerNormMLP(TransformerEngineBaseModule):
activation_map = { activation_map = {
"gelu": lambda x: torch.nn.functional.gelu(x, approximate="tanh"), "gelu": lambda x: torch.nn.functional.gelu(x, approximate="tanh"),
"relu": torch.nn.functional.relu,
"geglu": lambda x: torch.nn.functional.gelu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1], "geglu": lambda x: torch.nn.functional.gelu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1],
"reglu": lambda x: torch.nn.functional.relu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1], "qgelu": lambda x: torch.nn.functional.gelu(x, approximate="tanh"),
"swiglu": lambda x: torch.nn.functional.silu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1],
"qgeglu": lambda x: torch.nn.functional.gelu(x.chunk(2, -1)[0], approximate="tanh") "qgeglu": lambda x: torch.nn.functional.gelu(x.chunk(2, -1)[0], approximate="tanh")
* x.chunk(2, -1)[1], * x.chunk(2, -1)[1],
"qgelu": lambda x: torch.nn.functional.gelu(x, approximate="tanh"), "relu": torch.nn.functional.relu,
"srelu": torch.nn.functional.softplus, "reglu": lambda x: torch.nn.functional.relu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1],
"srelu": lambda x: torch.nn.functional.relu(x) ** 2,
"sreglu": lambda x: torch.nn.functional.relu(x.chunk(2, -1)[0]) ** 2
* x.chunk(2, -1)[1],
"silu": torch.nn.functional.silu,
"swiglu": lambda x: torch.nn.functional.silu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1],
} }
if self.activation not in activation_map: if self.activation not in activation_map:
raise ValueError(f"Unsupported activation in onnx export: {self.activation}") raise ValueError(f"Unsupported activation in onnx export: {self.activation}")
...@@ -2129,7 +2148,7 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -2129,7 +2148,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
def _get_weight_quantizers(self) -> List[Quantizer]: def _get_weight_quantizers(self) -> List[Quantizer]:
"""Get the weight quantizers of the module.""" """Get the weight quantizers of the module."""
if not self.fp8: if not self.fp8 and not self.fp8_calibration:
return [None, None] return [None, None]
fc1_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] fc1_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
fc1_weight_quantizer.internal = True fc1_weight_quantizer.internal = True
...@@ -2182,10 +2201,8 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -2182,10 +2201,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
if self.fc1_bias.grad is None: if self.fc1_bias.grad is None:
self.fc1_bias.grad = fc1_bias_grad.to(self.fc1_bias.dtype) self.fc1_bias.grad = fc1_bias_grad.to(self.fc1_bias.dtype)
if not self.fuse_wgrad_accumulation: if not self.fuse_wgrad_accumulation:
if self.fc2_weight.grad is None: self.fc2_weight.grad = fc2_wgrad.to(self.fc2_weight.dtype)
self.fc2_weight.grad = fc2_wgrad.to(self.fc2_weight.dtype) self.fc1_weight.grad = fc1_wgrad.to(self.fc1_weight.dtype)
if self.fc1_weight.grad is None:
self.fc1_weight.grad = fc1_wgrad.to(self.fc1_weight.dtype)
del fc2_bias_grad_ del fc2_bias_grad_
del fc2_wgrad del fc2_wgrad
del fc1_wgrad del fc1_wgrad
......
...@@ -147,10 +147,10 @@ class _Linear(torch.autograd.Function): ...@@ -147,10 +147,10 @@ class _Linear(torch.autograd.Function):
ub_obj = None ub_obj = None
ub_type = None ub_type = None
if ub_overlap_rs_fprop: if ub_overlap_rs_fprop:
ub_obj = get_ub(ub_name + "_fprop") ub_obj = get_ub(ub_name + "_fprop", fp8)
ub_type = tex.CommOverlapType.RS ub_type = tex.CommOverlapType.RS
elif ub_overlap_ag_fprop: elif ub_overlap_ag_fprop:
ub_obj = get_ub(ub_name + "_fprop") ub_obj = get_ub(ub_name + "_fprop", fp8)
ub_type = tex.CommOverlapType.AG ub_type = tex.CommOverlapType.AG
# ------------------------------------------------------ # ------------------------------------------------------
...@@ -319,6 +319,13 @@ class _Linear(torch.autograd.Function): ...@@ -319,6 +319,13 @@ class _Linear(torch.autograd.Function):
# Finished forward GEMM... # Finished forward GEMM...
# ------------------------------------------------------ # ------------------------------------------------------
# Deallocate GEMM input tensor if no longer needed
# TODO(yuzhongw, tmoon): Figure out why inputmat_total is not automatically
# deallocated by GC. Manually deallocating is a temporary hack.
if with_input_all_gather_nccl:
clear_tensor_data(inputmat_total)
inputmat_total = None
# ------------------------------------------------------ # ------------------------------------------------------
# Prepare output tensor # Prepare output tensor
# Note: Perform tensor-parallel communication # Note: Perform tensor-parallel communication
...@@ -544,23 +551,23 @@ class _Linear(torch.autograd.Function): ...@@ -544,23 +551,23 @@ class _Linear(torch.autograd.Function):
dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]] dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]]
if ctx.ub_overlap_ag: if ctx.ub_overlap_ag:
# Overlap grad_output all-gather with dgrad compute # Overlap grad_output all-gather with dgrad compute
ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad") ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8)
ub_obj_dgrad = ctx.ub_obj_gradout ub_obj_dgrad = ctx.ub_obj_gradout
ub_type_dgrad = tex.CommOverlapType.AG ub_type_dgrad = tex.CommOverlapType.AG
elif ctx.ub_overlap_rs_dgrad: elif ctx.ub_overlap_rs_dgrad:
# Overlap dgrad reduce-scatter with dgrad compute # Overlap dgrad reduce-scatter with dgrad compute
ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad") ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8)
ub_obj_dgrad = ctx.ub_obj_gradout ub_obj_dgrad = ctx.ub_obj_gradout
ub_type_dgrad = tex.CommOverlapType.RS ub_type_dgrad = tex.CommOverlapType.RS
else: else:
if ctx.ub_bulk_dgrad: if ctx.ub_bulk_dgrad:
# Overlap inputmat all-gather with dgrad compute # Overlap inputmat all-gather with dgrad compute
ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad") ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8)
ub_obj_dgrad = ctx.ub_obj_gradout ub_obj_dgrad = ctx.ub_obj_gradout
ub_type_dgrad = tex.CommOverlapType.AG ub_type_dgrad = tex.CommOverlapType.AG
if ctx.ub_bulk_wgrad: if ctx.ub_bulk_wgrad:
# Overlap dgrad reduce-scatter with wgrad compute # Overlap dgrad reduce-scatter with wgrad compute
ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad") ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8)
ub_type_wgrad = tex.CommOverlapType.RS ub_type_wgrad = tex.CommOverlapType.RS
# -------------------------------------------------- # --------------------------------------------------
...@@ -793,7 +800,7 @@ class _Linear(torch.autograd.Function): ...@@ -793,7 +800,7 @@ class _Linear(torch.autograd.Function):
dgrad_send_stream, dgrad_recv_stream = ub_obj_dgrad.get_communication_stream() dgrad_send_stream, dgrad_recv_stream = ub_obj_dgrad.get_communication_stream()
# This object is separate from the ub_obj_wgrad object which is passed to the GEMM # This object is separate from the ub_obj_wgrad object which is passed to the GEMM
ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad") ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8)
ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True) ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
...@@ -905,9 +912,16 @@ class _Linear(torch.autograd.Function): ...@@ -905,9 +912,16 @@ class _Linear(torch.autograd.Function):
grad_bias = grad_bias_ grad_bias = grad_bias_
del grad_bias_ del grad_bias_
# Deallocate input tensor if permitted # Deallocate tensors if permitted
if ctx.owns_input: if ctx.owns_input:
# Input tensor is internal
clear_tensor_data(inputmat_total) clear_tensor_data(inputmat_total)
elif ctx.backward_input_needs_gather:
# Gathered input tensor is internal
clear_tensor_data(inputmat_total)
if ctx.parallel_mode == "row" and ctx.sequence_parallel:
# Gathered grad output tensor is internal
clear_tensor_data(grad_output)
# Update grad input if overlapping reduce-scatter with wgrad GEMM # Update grad input if overlapping reduce-scatter with wgrad GEMM
if ctx.ub_bulk_wgrad: if ctx.ub_bulk_wgrad:
...@@ -1404,10 +1418,14 @@ class Linear(TransformerEngineBaseModule): ...@@ -1404,10 +1418,14 @@ class Linear(TransformerEngineBaseModule):
is_first_microbatch = False is_first_microbatch = False
if self.ub_overlap_rs_fprop: if self.ub_overlap_rs_fprop:
if get_ub(self.ub_name + "_fprop").is_fp8_ubuf(): if get_ub(
self.ub_name + "_fprop", FP8GlobalStateManager.is_fp8_enabled()
).is_fp8_ubuf():
fp8_output = True fp8_output = True
if self.ub_overlap_rs_dgrad: if self.ub_overlap_rs_dgrad:
if get_ub(self.ub_name + "_dgrad").is_fp8_ubuf(): if get_ub(
self.ub_name + "_dgrad", FP8GlobalStateManager.is_fp8_enabled()
).is_fp8_ubuf():
fp8_grad = True fp8_grad = True
with torch.cuda.device( with torch.cuda.device(
...@@ -1666,7 +1684,7 @@ class Linear(TransformerEngineBaseModule): ...@@ -1666,7 +1684,7 @@ class Linear(TransformerEngineBaseModule):
def _get_weight_quantizers(self) -> List[Quantizer]: def _get_weight_quantizers(self) -> List[Quantizer]:
"""Get the weight quantizers of the module.""" """Get the weight quantizers of the module."""
if not self.fp8: if not self.fp8 and not self.fp8_calibration:
return [None] return [None]
weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
weight_quantizer.internal = True weight_quantizer.internal = True
......
...@@ -112,7 +112,9 @@ schema = defs.OpSchema( ...@@ -112,7 +112,9 @@ schema = defs.OpSchema(
doc="TRT FP8 Quantize Linear used for inference.", doc="TRT FP8 Quantize Linear used for inference.",
inputs=[ inputs=[
defs.OpSchema.FormalParameter("tensor", "tensor(float)", "Input tensor to quantize"), defs.OpSchema.FormalParameter("tensor", "tensor(float)", "Input tensor to quantize"),
defs.OpSchema.FormalParameter("scale", "tensor(float)", "Scale factor for quantization"), defs.OpSchema.FormalParameter(
"scale_inv", "tensor(float)", "Inverse scale factor for quantization"
),
], ],
outputs=[defs.OpSchema.FormalParameter("output", "tensor(uint8)", "Quantized output tensor")], outputs=[defs.OpSchema.FormalParameter("output", "tensor(uint8)", "Quantized output tensor")],
) )
...@@ -126,11 +128,10 @@ TRT_FP8QuantizeLinear = onnxscript.values.Op( ...@@ -126,11 +128,10 @@ TRT_FP8QuantizeLinear = onnxscript.values.Op(
@torch.library.custom_op("tex::fp8_dequantize", mutates_args=[]) @torch.library.custom_op("tex::fp8_dequantize", mutates_args=[])
def onnx_dequantize_fp8_op(tensor: torch.Tensor, scale: float) -> torch.Tensor: def onnx_dequantize_fp8_op(tensor: torch.Tensor, scale_inv: torch.Tensor) -> torch.Tensor:
"""Dequantize from Float8Tensor used for inference.""" """Dequantize from Float8Tensor used for inference."""
scale_tensor = torch.tensor(scale, dtype=torch.float32, device=tensor.device)
quantizer = Float8Quantizer( quantizer = Float8Quantizer(
scale_tensor, torch.zeros(1).to(tensor.device), tex.DType.kFloat8E4M3 1 / scale_inv, torch.zeros(1).to(tensor.device), tex.DType.kFloat8E4M3
) )
quantizer_tensor = quantizer.create_tensor_from_data(tensor, fake_dtype=torch.float32) quantizer_tensor = quantizer.create_tensor_from_data(tensor, fake_dtype=torch.float32)
return quantizer_tensor.dequantize() return quantizer_tensor.dequantize()
...@@ -143,10 +144,9 @@ def _(tensor: torch.Tensor, _) -> torch.Tensor: ...@@ -143,10 +144,9 @@ def _(tensor: torch.Tensor, _) -> torch.Tensor:
def onnx_dequantize_fp8_symbolic( def onnx_dequantize_fp8_symbolic(
tensor: onnxscript.onnx_types.TensorType, scale: float tensor: onnxscript.onnx_types.TensorType, scale_inv: onnxscript.onnx_types.TensorType
) -> onnxscript.onnx_types.TensorType: ) -> onnxscript.onnx_types.TensorType:
"""Symbolic dequantize from Float8Tensor used for inference.""" """Symbolic dequantize from Float8Tensor used for inference."""
scale_inv = op.Constant(value_float=1 / scale)
return TRT_FP8DequantizeLinear(tensor, scale_inv) return TRT_FP8DequantizeLinear(tensor, scale_inv)
...@@ -157,7 +157,9 @@ schema = defs.OpSchema( ...@@ -157,7 +157,9 @@ schema = defs.OpSchema(
doc="TRT FP8 Dequantize Linear from Float8Tensor used for inference.", doc="TRT FP8 Dequantize Linear from Float8Tensor used for inference.",
inputs=[ inputs=[
defs.OpSchema.FormalParameter("tensor", "tensor(uint8)", "Input tensor to dequantize"), defs.OpSchema.FormalParameter("tensor", "tensor(uint8)", "Input tensor to dequantize"),
defs.OpSchema.FormalParameter("scale", "tensor(float)", "Scale factor for dequantization"), defs.OpSchema.FormalParameter(
"scale_inv", "tensor(float)", "Inverse scale factor for dequantization"
),
], ],
outputs=[defs.OpSchema.FormalParameter("output", "tensor(float)", "Dequantized output tensor")], outputs=[defs.OpSchema.FormalParameter("output", "tensor(float)", "Dequantized output tensor")],
) )
...@@ -166,6 +168,43 @@ TRT_FP8DequantizeLinear = onnxscript.values.Op( ...@@ -166,6 +168,43 @@ TRT_FP8DequantizeLinear = onnxscript.values.Op(
opset=trt_opset, name="TRT_FP8DequantizeLinear", op_schema=schema opset=trt_opset, name="TRT_FP8DequantizeLinear", op_schema=schema
) )
# ONNX FP8 Current Scaling Quantization
@torch.library.custom_op("tex::fp8_cs_quantize", mutates_args=[])
def onnx_cs_quantize_fp8_op(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Quantize to FP8 with current scaling; returns (uint8, scale_inv)."""
if tensor.dtype != torch.float32:
tensor = tensor.to(torch.float32)
amax = tensor.abs().max()
eps = torch.tensor(1e-12, dtype=torch.float32, device=tensor.device)
amax = torch.maximum(amax, eps)
fp8_max = torch.tensor(448, dtype=torch.float32, device=tensor.device)
scale = fp8_max / amax
q = torch.ops.tex.fp8_quantize(tensor, scale)
scale_inv = 1 / scale
return q, scale_inv
@onnx_cs_quantize_fp8_op.register_fake
def _(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
return torch.empty(tensor.shape, dtype=torch.uint8, device=tensor.device), torch.ones(
1, dtype=torch.float32, device=tensor.device
)
def onnx_quantize_fp8_cs_symbolic(
tensor: onnxscript.onnx_types.TensorType,
):
"""Symbolic quantize with current scaling; computes scale_inv from tensor."""
# scale_inv = 1 / max(abs(tensor))
amax = op.ReduceMax(op.Abs(tensor), keepdims=0)
eps = op.Constant(value_float=1.0e-12)
amax = op.Max(amax, eps)
scale_inv = op.Div(amax, op.Constant(value_float=448.0))
q = TRT_FP8QuantizeLinear(tensor, scale_inv)
return q, scale_inv
# ONNX MXFP8 Quantization # ONNX MXFP8 Quantization
...@@ -194,12 +233,12 @@ def onnx_quantize_mxfp8_symbolic( ...@@ -194,12 +233,12 @@ def onnx_quantize_mxfp8_symbolic(
tensor: onnxscript.onnx_types.TensorType, tensor: onnxscript.onnx_types.TensorType,
) -> Tuple[onnxscript.onnx_types.TensorType, onnxscript.onnx_types.TensorType]: ) -> Tuple[onnxscript.onnx_types.TensorType, onnxscript.onnx_types.TensorType]:
"""Symbolic quantize to MXFP8Tensor used for inference.""" """Symbolic quantize to MXFP8Tensor used for inference."""
tensor_out, scale_inv_out = TRT_MXFP8QuantizeLinear(tensor) tensor_out, scale_inv_out = TRT_MXFP8DynamicQuantize(tensor)
return tensor_out, scale_inv_out return tensor_out, scale_inv_out
schema = defs.OpSchema( schema = defs.OpSchema(
name="TRT_MXFP8QuantizeLinear", name="TRT_MXFP8DynamicQuantize",
domain="trt", domain="trt",
since_version=1, since_version=1,
doc="TRT MXFP8 Quantize Linear used for inference.", doc="TRT MXFP8 Quantize Linear used for inference.",
...@@ -214,8 +253,8 @@ schema = defs.OpSchema( ...@@ -214,8 +253,8 @@ schema = defs.OpSchema(
], ],
) )
TRT_MXFP8QuantizeLinear = onnxscript.values.Op( TRT_MXFP8DynamicQuantize = onnxscript.values.Op(
opset=trt_opset, name="TRT_MXFP8QuantizeLinear", op_schema=schema opset=trt_opset, name="TRT_MXFP8DynamicQuantize", op_schema=schema
) )
...@@ -356,6 +395,7 @@ te_translation_table = { ...@@ -356,6 +395,7 @@ te_translation_table = {
torch.ops.tex.gemm_inf.default: onnx_gemm_inf_symbolic, torch.ops.tex.gemm_inf.default: onnx_gemm_inf_symbolic,
torch.ops.tex.fp8_quantize.default: onnx_quantize_fp8_symbolic, torch.ops.tex.fp8_quantize.default: onnx_quantize_fp8_symbolic,
torch.ops.tex.fp8_dequantize.default: onnx_dequantize_fp8_symbolic, torch.ops.tex.fp8_dequantize.default: onnx_dequantize_fp8_symbolic,
torch.ops.tex.fp8_cs_quantize.default: onnx_quantize_fp8_cs_symbolic,
torch.ops.tex.mxfp8_quantize.default: onnx_quantize_mxfp8_symbolic, torch.ops.tex.mxfp8_quantize.default: onnx_quantize_mxfp8_symbolic,
torch.ops.tex.mxfp8_dequantize.default: onnx_dequantize_mxfp8_symbolic, torch.ops.tex.mxfp8_dequantize.default: onnx_dequantize_mxfp8_symbolic,
torch.ops.tex.layernorm.default: onnx_layernorm_symbolic, torch.ops.tex.layernorm.default: onnx_layernorm_symbolic,
......
...@@ -29,7 +29,9 @@ def maybe_dequantize( ...@@ -29,7 +29,9 @@ def maybe_dequantize(
if is_quantized_tensor(tensor): if is_quantized_tensor(tensor):
return tensor.dequantize(dtype=dtype) return tensor.dequantize(dtype=dtype)
if dtype is not None and tensor.dtype != dtype: if dtype is not None and tensor.dtype != dtype:
return tensor.to(dtype) tensor = tensor.to(dtype)
if not tensor.is_contiguous():
tensor = tensor.contiguous()
return tensor return tensor
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
"""Single tensor operations supported by the operation fuser.""" """Single tensor operations supported by the operation fuser."""
from .activation import GELU, ReLU, GEGLU, ReGLU, SwiGLU from .activation import GELU, GEGLU, QGELU, QGEGLU, ReLU, ReGLU, SReLU, SReGLU, SiLU, SwiGLU
from .add_extra_input import AddExtraInput from .add_extra_input import AddExtraInput
from .all_gather import AllGather from .all_gather import AllGather
from .all_reduce import AllReduce from .all_reduce import AllReduce
......
...@@ -11,11 +11,25 @@ from typing import Optional ...@@ -11,11 +11,25 @@ from typing import Optional
import torch import torch
import transformer_engine_torch as tex import transformer_engine_torch as tex
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...tensor.float8_tensor import Float8CurrentScalingQuantizer, Quantizer from ...tensor.float8_tensor import Float8CurrentScalingQuantizer, Quantizer
from ...utils import clear_tensor_data from ...utils import clear_tensor_data
from ..op import BasicOperation, OperationContext from ..op import BasicOperation, OperationContext
from .._common import maybe_dequantize from .._common import maybe_dequantize
__all__ = [
"GELU",
"GEGLU",
"QGELU",
"QGEGLU",
"ReLU",
"ReGLU",
"SReLU",
"SReGLU",
"SiLU",
"SwiGLU",
]
class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta): class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
r"""Apply activation function r"""Apply activation function
...@@ -97,6 +111,8 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta): ...@@ -97,6 +111,8 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
# Save state for backward pass # Save state for backward pass
if ctx.requires_grad: if ctx.requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(x)
ctx.save_for_backward(x) ctx.save_for_backward(x)
ctx.dtype = dtype ctx.dtype = dtype
ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer
...@@ -147,37 +163,75 @@ class GELU(_ActivationOperation): ...@@ -147,37 +163,75 @@ class GELU(_ActivationOperation):
return tex.dgelu(*args, **kwargs) return tex.dgelu(*args, **kwargs)
class ReLU(_ActivationOperation): class GEGLU(_ActivationOperation):
r"""Rectified linear unit r"""Gaussian Error Gated Linear Unit
The input tensor is split into chunks :math:`a` and :math:`b`
along the last dimension and the following is computed:
.. math:: .. math::
\text{ReLU}(x) = \max(x,0) \text{GEGLU}(a,b) = \text{GELU}(a) * b
where
.. math::
\text{GELU}(x) \approx \frac{x}{2} \left( 1 + \tanh\left( 0.797x+0.036 x^3 \right) \right)
.. warning::
Transformer Engine's gated activations and PyTorch's GLU
activation follow opposite conventions for :math:`a` and
:math:`b`. Transformer Engine applies the gating function to
the first half of the input tensor, while PyTorch applies it to
the second half.
See `GLU Variants Improve Transformer<https://arxiv.org/abs/2002.05202>`__.
""" """
def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor:
return tex.relu(*args, **kwargs) return tex.geglu(*args, **kwargs)
def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
return tex.drelu(*args, **kwargs) return tex.dgeglu(*args, **kwargs)
class GEGLU(_ActivationOperation): class QGELU(_ActivationOperation):
r"""Gaussian error gated linear unit r"""Quick Gaussian Error Linear Unit
Quick GELU from `HuggingFace<https://github.com/huggingface/transformers/blob/3e93dd295b5343557a83bc07b0b2ea64c926f9b4/src/transformers/activations.py#L90>`__
and `paper<https://github.com/hendrycks/GELUs>`__.
.. math::
\text{QGELU}(x) \approx x * \sigma(1.702 * x)
"""
def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor:
return tex.qgelu(*args, **kwargs)
def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
return tex.dqgelu(*args, **kwargs)
class QGEGLU(_ActivationOperation):
r"""Quick Gaussian Error Gated Linear Unit
The input tensor is split into chunks :math:`a` and :math:`b` The input tensor is split into chunks :math:`a` and :math:`b`
along the last dimension and the following is computed: along the last dimension and the following is computed:
.. math:: .. math::
\text{GEGLU}(a,b) = \text{GELU}(a) * b \text{QGEGLU}(a,b) = \text{QGELU}(a) * b
where where
.. math:: .. math::
\text{GELU}(x) \approx \frac{x}{2} \left( 1 + \tanh\left( 0.797x+0.036 x^3 \right) \right) \text{QGELU}(x) \approx x * \sigma(1.702 * x)
.. warning:: .. warning::
...@@ -187,19 +241,33 @@ class GEGLU(_ActivationOperation): ...@@ -187,19 +241,33 @@ class GEGLU(_ActivationOperation):
the first half of the input tensor, while PyTorch applies it to the first half of the input tensor, while PyTorch applies it to
the second half. the second half.
See `GLU Variants Improve Transformer<https://arxiv.org/abs/2002.05202>`__. """
def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor:
return tex.qgeglu(*args, **kwargs)
def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
return tex.dqgeglu(*args, **kwargs)
class ReLU(_ActivationOperation):
r"""Rectified Linear Unit
.. math::
\text{ReLU}(x) = \max(x,0)
""" """
def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor:
return tex.geglu(*args, **kwargs) return tex.relu(*args, **kwargs)
def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
return tex.dgeglu(*args, **kwargs) return tex.drelu(*args, **kwargs)
class ReGLU(_ActivationOperation): class ReGLU(_ActivationOperation):
r"""Rectified gated linear unit r"""Rectified Gated Linear Unit
The input tensor is split into chunks :math:`a` and :math:`b` The input tensor is split into chunks :math:`a` and :math:`b`
along the last dimension and the following is computed: along the last dimension and the following is computed:
...@@ -227,6 +295,67 @@ class ReGLU(_ActivationOperation): ...@@ -227,6 +295,67 @@ class ReGLU(_ActivationOperation):
return tex.dreglu(*args, **kwargs) return tex.dreglu(*args, **kwargs)
class SReLU(_ActivationOperation):
    r"""Squared Rectified Linear Unit

    .. math::

       \text{SReLU}(x) = \max(x,0)^2

    See `Primer: Searching for Efficient Transformers for Language Modeling<https://arxiv.org/abs/2109.08668v2>`__.

    """

    def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor:
        """Forward impl: delegate to the transformer_engine_torch srelu kernel."""
        return tex.srelu(*args, **kwargs)

    def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
        """Backward impl: delegate to the transformer_engine_torch dsrelu kernel."""
        return tex.dsrelu(*args, **kwargs)
class SReGLU(_ActivationOperation):
    r"""Squared Rectified Gated Linear Unit

    The input tensor is split into chunks :math:`a` and :math:`b`
    along the last dimension and the following is computed:

    .. math::

       \text{SReGLU}(a,b) = \max(a,0)^2 * b

    .. warning::

       Transformer Engine's gated activations and PyTorch's GLU
       activation follow opposite conventions for :math:`a` and
       :math:`b`. Transformer Engine applies the gating function to
       the first half of the input tensor, while PyTorch applies it to
       the second half.

    """

    def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor:
        """Forward impl: delegate to the transformer_engine_torch sreglu kernel."""
        return tex.sreglu(*args, **kwargs)

    def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
        """Backward impl: delegate to the transformer_engine_torch dsreglu kernel."""
        return tex.dsreglu(*args, **kwargs)
class SiLU(_ActivationOperation):
    r"""Sigmoid Linear Unit

    .. math::

       \text{SiLU}(x) = x \sigma(x) = \frac{x}{1+\exp(-x)}

    """

    def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor:
        """Forward impl: delegate to the transformer_engine_torch silu kernel."""
        return tex.silu(*args, **kwargs)

    def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
        """Backward impl: delegate to the transformer_engine_torch dsilu kernel."""
        return tex.dsilu(*args, **kwargs)
class SwiGLU(_ActivationOperation): class SwiGLU(_ActivationOperation):
r"""Swish gated linear unit r"""Swish gated linear unit
......
...@@ -12,26 +12,32 @@ from typing import Any, Optional ...@@ -12,26 +12,32 @@ from typing import Any, Optional
import torch import torch
from transformer_engine.pytorch.module.base import get_workspace
from ...cpp_extensions import general_gemm from ...cpp_extensions import general_gemm
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...distributed import ( from ...distributed import (
CudaRNGStatesTracker, CudaRNGStatesTracker,
gather_along_first_dim, gather_along_first_dim,
reduce_scatter_along_first_dim, reduce_scatter_along_first_dim,
) )
from ...fp8 import FP8GlobalStateManager, Recipe from ...fp8 import FP8GlobalStateManager, Recipe
from ...module.base import _2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD from ...module.base import (
_2X_ACC_FPROP,
_2X_ACC_DGRAD,
_2X_ACC_WGRAD,
get_dummy_wgrad,
get_workspace,
)
from ...tensor import Quantizer from ...tensor import Quantizer
from ...tensor.float8_tensor import Float8Quantizer from ...tensor.float8_tensor import Float8Quantizer
from ...tensor._internal.float8_tensor_base import Float8TensorBase from ...tensor._internal.float8_tensor_base import Float8TensorBase
from ..op import BasicOperation, OperationContext
from .._common import maybe_dequantize, is_quantized_tensor
from ...utils import ( from ...utils import (
canonicalize_device, canonicalize_device,
canonicalize_dtype, canonicalize_dtype,
clear_tensor_data, clear_tensor_data,
devices_match, devices_match,
) )
from ..op import BasicOperation, OperationContext
from .._common import maybe_dequantize, is_quantized_tensor
def _wait_async(handle: Optional[Any]) -> None: def _wait_async(handle: Optional[Any]) -> None:
...@@ -73,7 +79,8 @@ class BasicLinear(BasicOperation): ...@@ -73,7 +79,8 @@ class BasicLinear(BasicOperation):
weight's `main_grad` attribute instead of relying on PyTorch weight's `main_grad` attribute instead of relying on PyTorch
autograd. The weight's `main_grad` must be set externally and autograd. The weight's `main_grad` must be set externally and
there is no guarantee that `grad` will be set or be there is no guarantee that `grad` will be set or be
meaningful. meaningful. This is primarily intented to integrate with
Megatron-LM.
userbuffers_options, dict, optional userbuffers_options, dict, optional
Options for overlapping tensor-parallel communication with Options for overlapping tensor-parallel communication with
compute using Userbuffers. This feature is highly compute using Userbuffers. This feature is highly
...@@ -958,6 +965,8 @@ class BasicLinear(BasicOperation): ...@@ -958,6 +965,8 @@ class BasicLinear(BasicOperation):
# Save state for backward pass # Save state for backward pass
if ctx.requires_grad: if ctx.requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(x_local)
ctx.save_for_backward(x_local, w) ctx.save_for_backward(x_local, w)
ctx.with_quantized_compute = with_quantized_compute ctx.with_quantized_compute = with_quantized_compute
ctx.input_quantizer = input_quantizer ctx.input_quantizer = input_quantizer
...@@ -979,20 +988,22 @@ class BasicLinear(BasicOperation): ...@@ -979,20 +988,22 @@ class BasicLinear(BasicOperation):
# Saved tensors from forward pass # Saved tensors from forward pass
(x_local, w) = ctx.saved_tensors (x_local, w) = ctx.saved_tensors
# wgrad fusion # Megatron-LM wgrad fusion
# Note: Get grad tensor from param so we can accumulate
# directly into it.
accumulate_into_main_grad = self._accumulate_into_main_grad accumulate_into_main_grad = self._accumulate_into_main_grad
grad_weight = None grad_weight = None
if ctx.weight_requires_grad and accumulate_into_main_grad: if ctx.weight_requires_grad and accumulate_into_main_grad:
if hasattr(self.weight, "__fsdp_param__"): weight_param = self.weight
self.weight.main_grad = self.weight.get_main_grad() if hasattr(weight_param, "__fsdp_param__"):
weight_param.main_grad = weight_param.get_main_grad()
if not hasattr(self.weight, "main_grad"): if not hasattr(weight_param, "main_grad"):
raise RuntimeError( raise RuntimeError(
"BasicLinear op is configured with " "BasicLinear op is configured with "
"accumulate_into_main_grad=True, " "accumulate_into_main_grad=True, "
"but weight parameter does not have main_grad attribute" "but weight parameter does not have main_grad attribute"
) )
grad_weight = self.weight.main_grad.detach() grad_weight = weight_param.main_grad.detach()
else: else:
accumulate_into_main_grad = False accumulate_into_main_grad = False
...@@ -1019,6 +1030,17 @@ class BasicLinear(BasicOperation): ...@@ -1019,6 +1030,17 @@ class BasicLinear(BasicOperation):
# Clear input tensor if possible # Clear input tensor if possible
clear_tensor_data(x_local) clear_tensor_data(x_local)
# Megatron-LM wgrad fusion
# Note: Return dummy tensor for grad weight if needed.
if accumulate_into_main_grad: if accumulate_into_main_grad:
grad_weight = None grad_weight = None
weight_param = self.weight
if hasattr(weight_param, "grad_added_to_main_grad"):
weight_param.grad_added_to_main_grad = True
grad_weight = get_dummy_wgrad(
list(weight_param.size()),
weight_param.dtype,
zero=getattr(weight_param, "zero_out_wgrad", False),
)
return grad_input, [grad_weight] return grad_input, [grad_weight]
...@@ -8,12 +8,12 @@ from __future__ import annotations ...@@ -8,12 +8,12 @@ from __future__ import annotations
from typing import Optional from typing import Optional
import torch import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.ops.op import ( from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
BasicOperation,
OperationContext,
)
from ...tensor import Quantizer from ...tensor import Quantizer
from ...tensor._internal.float8_tensor_base import Float8TensorBase
from .._common import maybe_autocast_dtype, maybe_dequantize
from ..op import BasicOperation, OperationContext
class Dropout(BasicOperation): class Dropout(BasicOperation):
...@@ -27,7 +27,7 @@ class Dropout(BasicOperation): ...@@ -27,7 +27,7 @@ class Dropout(BasicOperation):
def __init__(self, p: float) -> None: def __init__(self, p: float) -> None:
super().__init__() super().__init__()
self.dropout_probability = p self.dropout_probability: float = p
def op_forward( def op_forward(
self, self,
...@@ -37,21 +37,46 @@ class Dropout(BasicOperation): ...@@ -37,21 +37,46 @@ class Dropout(BasicOperation):
next_op_input_quantizer: Optional[Quantizer], next_op_input_quantizer: Optional[Quantizer],
) -> torch.Tensor: ) -> torch.Tensor:
# Compute dropout if training # Output dtype
out = input_ dtype = maybe_autocast_dtype(default_dtype=input_.dtype)
is_training = self.training
mask = None # Choose implementation
if is_training: impl = None
if not self.training:
impl = "evaluation"
elif input_.numel() % 16 == 0 and dtype in (torch.float16, torch.bfloat16):
impl = "fused"
else:
impl = "unfused"
# Perform dropout
out: torch.Tensor
mask: Optional[torch.Tensor] = None
if impl == "evaluation":
out = input_
elif impl == "fused":
x = input_
if not isinstance(x, Float8TensorBase):
x = maybe_dequantize(x, dtype=dtype)
out, mask = tex.dropout_fwd(x, self.dropout_probability)
elif impl == "unfused":
x = maybe_dequantize(input_, dtype=dtype)
keep_prob = 1 - self.dropout_probability keep_prob = 1 - self.dropout_probability
mask = torch.empty_like(input_) mask = torch.empty_like(x)
mask.bernoulli_(keep_prob) mask.bernoulli_(keep_prob)
mask *= 1 / keep_prob mask *= 1 / keep_prob
out = out * mask out = x * mask
else:
raise ValueError(f"Unsupported forward implementation {impl}")
# Save context for backward # Save context for backward
if ctx.requires_grad: if ctx.requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(mask)
ctx.save_for_backward(mask) ctx.save_for_backward(mask)
ctx.is_training = is_training ctx.impl = impl
ctx.dropout_probability = self.dropout_probability
ctx.dtype = dtype
return out return out
...@@ -60,8 +85,21 @@ class Dropout(BasicOperation): ...@@ -60,8 +85,21 @@ class Dropout(BasicOperation):
ctx: OperationContext, ctx: OperationContext,
grad_output: torch.Tensor, grad_output: torch.Tensor,
) -> tuple[torch.Tensor, tuple[()]]: ) -> tuple[torch.Tensor, tuple[()]]:
# Saved tensors from forward pass
(mask,) = ctx.saved_tensors (mask,) = ctx.saved_tensors
grad_input = grad_output
if ctx.is_training: # Perform dropout backward pass
grad_input = grad_input * mask grad_input: torch.Tensor
if ctx.impl == "evaluation":
grad_input = grad_output
elif ctx.impl == "fused":
dy = maybe_dequantize(grad_output, dtype=ctx.dtype)
grad_input = tex.dropout_bwd(dy, mask, ctx.dropout_probability)
elif ctx.impl == "unfused":
dy = maybe_dequantize(grad_output, dtype=ctx.dtype)
grad_input = dy * mask
else:
raise ValueError(f"Unsupported backward implementation {ctx.impl}")
return grad_input, () return grad_input, ()
...@@ -10,10 +10,8 @@ import os ...@@ -10,10 +10,8 @@ import os
import torch import torch
from ...utils import clear_tensor_data
from ... import torch_version from ... import torch_version
from .._common import maybe_dequantize from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ..op import BasicOperation, OperationContext
from ...jit import ( from ...jit import (
l2normalization_fused, l2normalization_fused,
l2normalization_fwd_fused, l2normalization_fwd_fused,
...@@ -22,6 +20,9 @@ from ...jit import ( ...@@ -22,6 +20,9 @@ from ...jit import (
warmup_jit_l2normalization_all_dtypes, warmup_jit_l2normalization_all_dtypes,
) )
from ...tensor import Quantizer from ...tensor import Quantizer
from ...utils import clear_tensor_data
from ..op import BasicOperation, OperationContext
from .._common import maybe_dequantize
class L2Normalization(BasicOperation): class L2Normalization(BasicOperation):
...@@ -101,6 +102,8 @@ class L2Normalization(BasicOperation): ...@@ -101,6 +102,8 @@ class L2Normalization(BasicOperation):
# Save state for backward pass # Save state for backward pass
if requires_grad: if requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(x, rsqrt_norm)
ctx.save_for_backward(x, rsqrt_norm) ctx.save_for_backward(x, rsqrt_norm)
return y return y
......
...@@ -14,6 +14,9 @@ import torch ...@@ -14,6 +14,9 @@ import torch
from transformer_engine_torch import layernorm_bwd, layernorm_fwd from transformer_engine_torch import layernorm_bwd, layernorm_fwd
from ...constants import TE_DType from ...constants import TE_DType
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...export import is_in_onnx_export_mode
from ...tensor import Quantizer
from ...utils import ( from ...utils import (
canonicalize_device, canonicalize_device,
canonicalize_dtype, canonicalize_dtype,
...@@ -22,8 +25,6 @@ from ...utils import ( ...@@ -22,8 +25,6 @@ from ...utils import (
) )
from ..op import BasicOperation, OperationContext from ..op import BasicOperation, OperationContext
from .._common import maybe_autocast_dtype, maybe_dequantize from .._common import maybe_autocast_dtype, maybe_dequantize
from ...export import is_in_onnx_export_mode
from ...tensor import Quantizer
class LayerNorm(BasicOperation): class LayerNorm(BasicOperation):
...@@ -215,6 +216,8 @@ class LayerNorm(BasicOperation): ...@@ -215,6 +216,8 @@ class LayerNorm(BasicOperation):
# Save state for backward pass # Save state for backward pass
if ctx.requires_grad: if ctx.requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(x, means, rstdevs)
ctx.save_for_backward(x, means, rstdevs) ctx.save_for_backward(x, means, rstdevs)
ctx.dtype = dtype ctx.dtype = dtype
......
...@@ -14,6 +14,9 @@ import torch ...@@ -14,6 +14,9 @@ import torch
from transformer_engine_torch import rmsnorm_bwd, rmsnorm_fwd from transformer_engine_torch import rmsnorm_bwd, rmsnorm_fwd
from ...constants import TE_DType from ...constants import TE_DType
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...export import is_in_onnx_export_mode
from ...tensor import Quantizer
from ...utils import ( from ...utils import (
canonicalize_device, canonicalize_device,
canonicalize_dtype, canonicalize_dtype,
...@@ -22,8 +25,6 @@ from ...utils import ( ...@@ -22,8 +25,6 @@ from ...utils import (
) )
from ..op import BasicOperation, OperationContext from ..op import BasicOperation, OperationContext
from .._common import maybe_autocast_dtype, maybe_dequantize from .._common import maybe_autocast_dtype, maybe_dequantize
from ...export import is_in_onnx_export_mode
from ...tensor import Quantizer
class RMSNorm(BasicOperation): class RMSNorm(BasicOperation):
...@@ -196,6 +197,8 @@ class RMSNorm(BasicOperation): ...@@ -196,6 +197,8 @@ class RMSNorm(BasicOperation):
# Save state for backward pass # Save state for backward pass
if ctx.requires_grad: if ctx.requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(x, rstdevs)
ctx.save_for_backward(x, rstdevs) ctx.save_for_backward(x, rstdevs)
ctx.dtype = dtype ctx.dtype = dtype
......
...@@ -8,6 +8,10 @@ from .backward_activation_bias import ( ...@@ -8,6 +8,10 @@ from .backward_activation_bias import (
BackwardActivationBias, BackwardActivationBias,
fuse_backward_activation_bias, fuse_backward_activation_bias,
) )
from .backward_add_rmsnorm import (
BackwardAddRMSNorm,
fuse_backward_add_rmsnorm,
)
from .backward_linear_add import ( from .backward_linear_add import (
BackwardLinearAdd, BackwardLinearAdd,
fuse_backward_linear_add, fuse_backward_linear_add,
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Fused backward RMNorm + add."""
from __future__ import annotations
from typing import Optional
import math
import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.ops.basic import MakeExtraOutput, RMSNorm
from transformer_engine.pytorch.ops.op import (
FusedOperation,
FusibleOperation,
OperationContext,
)
from ...utils import clear_tensor_data
from .._common import maybe_dequantize
class BackwardAddRMSNorm(FusedOperation):
    """Fused backward RMSNorm + add

    Fuses the RMSNorm backward pass with the addition of an extra
    gradient (from a ``MakeExtraOutput`` op) into a single
    ``tex.rmsnorm_bwd_add`` kernel call, avoiding a separate
    elementwise add over the input-sized gradient tensor.
    """

    def __init__(self, *, add: MakeExtraOutput, rmsnorm: RMSNorm):
        super().__init__((add, rmsnorm))

    def fuser_backward(
        self,
        basic_op_ctxs: list[OperationContext],
        grad_output: torch.Tensor,
        *,
        basic_op_grad_extra_outputs: list[tuple[torch.Tensor, ...]],
    ) -> tuple[
        torch.Tensor,
        list[tuple[Optional[torch.Tensor], ...]],
        list[tuple[()]],
    ]:
        """Backward pass for the fused op.

        Returns the gradient w.r.t. the fused op's input, per-basic-op
        parameter gradients, and per-basic-op extra-input gradients.
        """

        # Get basic operations
        # NOTE(review): basic_ops was constructed as (add, rmsnorm), so
        # index 1 is the RMSNorm op, while the context and grad lists use
        # index 0 for RMSNorm — confirm this index alignment against the
        # fuser's ordering convention.
        rmsnorm_op = self.basic_ops[1]
        rmsnorm_op_ctx = basic_op_ctxs[0]

        # Saved tensors from forward pass
        x, rstdevs = rmsnorm_op_ctx.saved_tensors

        # Tensor dims
        # inner_dim: flattened size of the RMSNorm weight, i.e. the
        # normalized trailing dimension(s) of the input.
        weight_dims = rmsnorm_op.weight.size()
        inner_dim = math.prod(weight_dims)

        # Check input tensors
        # Dequantize (if needed) and reshape so the kernel sees plain
        # tensors: dy/add match x's shape, w is flattened to 1-D.
        dtype = rmsnorm_op_ctx.dtype
        extra_grad = basic_op_grad_extra_outputs[1][0]
        dy = maybe_dequantize(grad_output.contiguous(), dtype).view(x.size())
        w = maybe_dequantize(rmsnorm_op.weight, dtype).view((inner_dim,))
        add = maybe_dequantize(extra_grad.contiguous(), dtype).view(x.size())

        # Compute RMSNorm backward pass
        # Single fused kernel: RMSNorm backward with the extra gradient
        # added in, producing grad input (dx) and grad weight (dw).
        dx, dw = tex.rmsnorm_bwd_add(
            dy,
            x,
            add,
            rstdevs,
            w,
            rmsnorm_op._sm_margins["backward"],
            rmsnorm_op.zero_centered_gamma,
        )

        # Clear saved tensors if possible
        # Frees activation memory as soon as it is no longer needed.
        clear_tensor_data(x)
        clear_tensor_data(rstdevs)

        # Reshape results
        grad_input = dx.view(grad_output.size())
        grad_weight = dw.view(weight_dims)

        return grad_input, [(grad_weight,), ()], [(), ()]
def fuse_backward_add_rmsnorm(
    ops: list[tuple[FusibleOperation, list[int]]],
) -> list[tuple[FusibleOperation, list[int]]]:
    """Fuse backward RMSNorm + add where possible.

    Scans the backward operation list for an ``RMSNorm`` op immediately
    followed by an out-of-place ``MakeExtraOutput`` op and replaces each
    such pair with a single ``BackwardAddRMSNorm`` fused op.

    Parameters
    ----------
    ops: list of tuples
        Backward pass operations and the indices of the corresponding
        basic operations.

    Returns
    -------
    ops: list of tuples
        Updated backward pass operations

    """

    # Slide a one-op candidate window over the list, fusing when the
    # next op completes the pattern.
    result = []
    candidate = []
    while len(ops) >= 2:
        result.extend(candidate)

        # Candidate must start with an RMSNorm op
        candidate, ops = ops[:1], ops[1:]
        first_op, _ = candidate[0]
        if not isinstance(first_op, RMSNorm):
            continue

        # Next op must be an out-of-place "make extra output"
        second_op, _ = ops[0]
        if not isinstance(second_op, MakeExtraOutput) or second_op._in_place:
            continue
        candidate.append(ops[0])
        ops = ops[1:]

        # Replace the matched pair with the fused op
        fused_op = BackwardAddRMSNorm(
            rmsnorm=candidate[0][0],
            add=candidate[1][0],
        )
        fused_idxs = [op_idxs[0] for _, op_idxs in candidate]
        candidate = [(fused_op, fused_idxs)]

    # Flush the remaining window and unexamined ops
    result.extend(candidate)
    result.extend(ops)
    return result
...@@ -9,13 +9,10 @@ from typing import Optional ...@@ -9,13 +9,10 @@ from typing import Optional
import torch import torch
from transformer_engine.pytorch.ops.basic import BasicLinear, MakeExtraOutput from ...module.base import get_dummy_wgrad
from transformer_engine.pytorch.ops.op import (
FusedOperation,
FusibleOperation,
OperationContext,
)
from ...utils import clear_tensor_data from ...utils import clear_tensor_data
from ..basic import BasicLinear, MakeExtraOutput
from ..op import FusedOperation, FusibleOperation, OperationContext
class BackwardLinearAdd(FusedOperation): class BackwardLinearAdd(FusedOperation):
...@@ -53,20 +50,22 @@ class BackwardLinearAdd(FusedOperation): ...@@ -53,20 +50,22 @@ class BackwardLinearAdd(FusedOperation):
# Saved tensors from forward pass # Saved tensors from forward pass
(x_local, w) = linear_op_ctx.saved_tensors (x_local, w) = linear_op_ctx.saved_tensors
# wgrad fusion # Megatron-LM wgrad fusion
# Note: Get grad tensor from param so we can accumulate
# directly into it.
accumulate_into_main_grad = linear_op._accumulate_into_main_grad accumulate_into_main_grad = linear_op._accumulate_into_main_grad
grad_weight = None grad_weight = None
if linear_op_ctx.weight_requires_grad and accumulate_into_main_grad: if linear_op_ctx.weight_requires_grad and accumulate_into_main_grad:
if hasattr(linear_op.weight, "__fsdp_param__"): weight_param = linear_op.weight
linear_op.weight.main_grad = linear_op.weight.get_main_grad() if hasattr(weight_param, "__fsdp_param__"):
weight_param.main_grad = weight_param.get_main_grad()
if not hasattr(linear_op.weight, "main_grad"): if not hasattr(weight_param, "main_grad"):
raise RuntimeError( raise RuntimeError(
"BasicLinear op is configured with " "BasicLinear op is configured with "
"accumulate_into_main_grad=True, " "accumulate_into_main_grad=True, "
"but weight parameter does not have main_grad attribute" "but weight parameter does not have main_grad attribute"
) )
grad_weight = linear_op.weight.main_grad.detach() grad_weight = weight_param.main_grad.detach()
else: else:
accumulate_into_main_grad = False accumulate_into_main_grad = False
...@@ -92,12 +91,23 @@ class BackwardLinearAdd(FusedOperation): ...@@ -92,12 +91,23 @@ class BackwardLinearAdd(FusedOperation):
grad_output_quantizer=linear_op_ctx.grad_output_quantizer, grad_output_quantizer=linear_op_ctx.grad_output_quantizer,
grad_input_quantizer=linear_op_ctx.grad_input_quantizer, grad_input_quantizer=linear_op_ctx.grad_input_quantizer,
) )
if accumulate_into_main_grad:
grad_weight = None
# Clear input tensor if possible # Clear input tensor if possible
clear_tensor_data(x_local) clear_tensor_data(x_local)
# Megatron-LM wgrad fusion
# Note: Return dummy tensor for grad weight if needed.
if accumulate_into_main_grad:
grad_weight = None
weight_param = linear_op.weight
if hasattr(weight_param, "grad_added_to_main_grad"):
weight_param.grad_added_to_main_grad = True
grad_weight = get_dummy_wgrad(
list(weight_param.size()),
weight_param.dtype,
zero=getattr(weight_param, "zero_out_wgrad", False),
)
return grad_input, [(grad_weight,), ()], [(), ()] return grad_input, [(grad_weight,), ()], [(), ()]
......
...@@ -9,13 +9,10 @@ from typing import Optional ...@@ -9,13 +9,10 @@ from typing import Optional
import torch import torch
from ..basic import BasicLinear, ConstantScale from ...module.base import get_dummy_wgrad
from ..op import (
FusedOperation,
FusibleOperation,
OperationContext,
)
from ...utils import clear_tensor_data from ...utils import clear_tensor_data
from ..basic import BasicLinear, ConstantScale
from ..op import FusedOperation, FusibleOperation, OperationContext
class BackwardLinearScale(FusedOperation): class BackwardLinearScale(FusedOperation):
...@@ -54,20 +51,22 @@ class BackwardLinearScale(FusedOperation): ...@@ -54,20 +51,22 @@ class BackwardLinearScale(FusedOperation):
# Saved tensors from forward pass # Saved tensors from forward pass
(x_local, w) = linear_op_ctx.saved_tensors (x_local, w) = linear_op_ctx.saved_tensors
# wgrad fusion # Megatron-LM wgrad fusion
# Note: Get grad tensor from param so we can accumulate
# directly into it.
accumulate_into_main_grad = linear_op._accumulate_into_main_grad accumulate_into_main_grad = linear_op._accumulate_into_main_grad
grad_weight = None grad_weight = None
if linear_op_ctx.weight_requires_grad and accumulate_into_main_grad: if linear_op_ctx.weight_requires_grad and accumulate_into_main_grad:
if hasattr(linear_op.weight, "__fsdp_param__"): weight_param = linear_op.weight
linear_op.weight.main_grad = linear_op.weight.get_main_grad() if hasattr(weight_param, "__fsdp_param__"):
weight_param.main_grad = weight_param.get_main_grad()
if not hasattr(linear_op.weight, "main_grad"): if not hasattr(weight_param, "main_grad"):
raise RuntimeError( raise RuntimeError(
"BasicLinear op is configured with " "BasicLinear op is configured with "
"accumulate_into_main_grad=True, " "accumulate_into_main_grad=True, "
"but weight parameter does not have main_grad attribute" "but weight parameter does not have main_grad attribute"
) )
grad_weight = linear_op.weight.main_grad.detach() grad_weight = weight_param.main_grad.detach()
else: else:
accumulate_into_main_grad = False accumulate_into_main_grad = False
...@@ -92,12 +91,23 @@ class BackwardLinearScale(FusedOperation): ...@@ -92,12 +91,23 @@ class BackwardLinearScale(FusedOperation):
grad_output_quantizer=linear_op_ctx.grad_output_quantizer, grad_output_quantizer=linear_op_ctx.grad_output_quantizer,
grad_input_quantizer=linear_op_ctx.grad_input_quantizer, grad_input_quantizer=linear_op_ctx.grad_input_quantizer,
) )
if accumulate_into_main_grad:
grad_weight = None
# Clear input tensor if possible # Clear input tensor if possible
clear_tensor_data(x_local) clear_tensor_data(x_local)
# Megatron-LM wgrad fusion
# Note: Return dummy tensor for grad weight if needed.
if accumulate_into_main_grad:
grad_weight = None
weight_param = linear_op.weight
if hasattr(weight_param, "grad_added_to_main_grad"):
weight_param.grad_added_to_main_grad = True
grad_weight = get_dummy_wgrad(
list(weight_param.size()),
weight_param.dtype,
zero=getattr(weight_param, "zero_out_wgrad", False),
)
return grad_input, [(), (grad_weight,)], [(), ()] return grad_input, [(), (grad_weight,)], [(), ()]
......
...@@ -10,14 +10,11 @@ from typing import Any, Optional ...@@ -10,14 +10,11 @@ from typing import Any, Optional
import torch import torch
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from transformer_engine.pytorch.ops.basic import BasicLinear, Bias from ...fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.ops.op import (
FusedOperation,
FusibleOperation,
OperationContext,
)
from ...tensor import Quantizer from ...tensor import Quantizer
from ..basic import BasicLinear, Bias
from ..op import FusedOperation, FusibleOperation, OperationContext
class ForwardLinearBiasActivation(FusedOperation): class ForwardLinearBiasActivation(FusedOperation):
...@@ -121,6 +118,8 @@ class ForwardLinearBiasActivation(FusedOperation): ...@@ -121,6 +118,8 @@ class ForwardLinearBiasActivation(FusedOperation):
# Save state for backward pass # Save state for backward pass
if linear_op_ctx.requires_grad: if linear_op_ctx.requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(x_local)
linear_op_ctx.save_for_backward(x_local, w) linear_op_ctx.save_for_backward(x_local, w)
linear_op_ctx.with_quantized_compute = with_quantized_compute linear_op_ctx.with_quantized_compute = with_quantized_compute
linear_op_ctx.input_quantizer = input_quantizer linear_op_ctx.input_quantizer = input_quantizer
......
...@@ -10,14 +10,11 @@ from typing import Any, Optional ...@@ -10,14 +10,11 @@ from typing import Any, Optional
import torch import torch
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from transformer_engine.pytorch.ops.basic import AddExtraInput, BasicLinear, Bias from ...fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.ops.op import ( from ...tensor import Quantizer
FusedOperation, from ..basic import AddExtraInput, BasicLinear, Bias
FusibleOperation, from ..op import FusedOperation, FusibleOperation, OperationContext
OperationContext,
)
from transformer_engine.pytorch.tensor import Quantizer
class ForwardLinearBiasAdd(FusedOperation): class ForwardLinearBiasAdd(FusedOperation):
...@@ -118,6 +115,8 @@ class ForwardLinearBiasAdd(FusedOperation): ...@@ -118,6 +115,8 @@ class ForwardLinearBiasAdd(FusedOperation):
# Save state for backward pass # Save state for backward pass
if linear_op_ctx.requires_grad: if linear_op_ctx.requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(x_local)
linear_op_ctx.save_for_backward(x_local, w) linear_op_ctx.save_for_backward(x_local, w)
linear_op_ctx.with_quantized_compute = with_quantized_compute linear_op_ctx.with_quantized_compute = with_quantized_compute
linear_op_ctx.input_quantizer = input_quantizer linear_op_ctx.input_quantizer = input_quantizer
......
...@@ -10,14 +10,15 @@ from typing import Any, Optional ...@@ -10,14 +10,15 @@ from typing import Any, Optional
import torch import torch
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...fp8 import FP8GlobalStateManager from ...fp8 import FP8GlobalStateManager
from ...tensor import Quantizer
from ..basic import AddExtraInput, BasicLinear, ConstantScale from ..basic import AddExtraInput, BasicLinear, ConstantScale
from ..op import ( from ..op import (
FusedOperation, FusedOperation,
FusibleOperation, FusibleOperation,
OperationContext, OperationContext,
) )
from ...tensor import Quantizer
class ForwardLinearScaleAdd(FusedOperation): class ForwardLinearScaleAdd(FusedOperation):
...@@ -95,6 +96,8 @@ class ForwardLinearScaleAdd(FusedOperation): ...@@ -95,6 +96,8 @@ class ForwardLinearScaleAdd(FusedOperation):
# Save state for backward pass # Save state for backward pass
if linear_op_ctx.requires_grad: if linear_op_ctx.requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(x_local)
linear_op_ctx.save_for_backward(x_local, w) linear_op_ctx.save_for_backward(x_local, w)
linear_op_ctx.with_quantized_compute = with_quantized_compute linear_op_ctx.with_quantized_compute = with_quantized_compute
linear_op_ctx.input_quantizer = input_quantizer linear_op_ctx.input_quantizer = input_quantizer
......
...@@ -14,11 +14,12 @@ from transformer_engine_torch import CommOverlapType, bulk_overlap_ag_with_exter ...@@ -14,11 +14,12 @@ from transformer_engine_torch import CommOverlapType, bulk_overlap_ag_with_exter
from ...cpp_extensions import general_gemm from ...cpp_extensions import general_gemm
from ...distributed import get_distributed_world_size from ...distributed import get_distributed_world_size
from ...module.base import ( from ...module.base import (
_2X_ACC_DGRAD,
_2X_ACC_WGRAD,
fill_userbuffers_buffer_for_all_gather, fill_userbuffers_buffer_for_all_gather,
get_dummy_wgrad,
get_ub, get_ub,
get_workspace, get_workspace,
_2X_ACC_DGRAD,
_2X_ACC_WGRAD,
) )
from ...tensor.quantized_tensor import Quantizer from ...tensor.quantized_tensor import Quantizer
from ...tensor.mxfp8_tensor import MXFP8Quantizer from ...tensor.mxfp8_tensor import MXFP8Quantizer
...@@ -240,16 +241,16 @@ class UserbuffersBackwardLinear(FusedOperation): ...@@ -240,16 +241,16 @@ class UserbuffersBackwardLinear(FusedOperation):
with_dgrad_all_gather_x = False with_dgrad_all_gather_x = False
with_wgrad_reduce_scatter_dx = False with_wgrad_reduce_scatter_dx = False
if tensor_parallel_mode == "row": if tensor_parallel_mode == "row":
ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad") ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad", with_quantized_compute)
ub_type_dgrad = CommOverlapType.AG ub_type_dgrad = CommOverlapType.AG
with_dgrad_all_gather_dy = True with_dgrad_all_gather_dy = True
elif tensor_parallel_mode == "column": elif tensor_parallel_mode == "column":
if input_requires_grad and weight_requires_grad: if input_requires_grad and weight_requires_grad:
with_bulk_overlap = True with_bulk_overlap = True
ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad") ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad", with_quantized_compute)
ub_type_dgrad = CommOverlapType.AG ub_type_dgrad = CommOverlapType.AG
with_dgrad_all_gather_x = True with_dgrad_all_gather_x = True
ub_comm_wgrad = get_ub(ub_comm_name + "_wgrad") ub_comm_wgrad = get_ub(ub_comm_name + "_wgrad", with_quantized_compute)
ub_type_wgrad = CommOverlapType.RS ub_type_wgrad = CommOverlapType.RS
with_wgrad_reduce_scatter_dx = True with_wgrad_reduce_scatter_dx = True
if ub_comm_wgrad.is_fp8_ubuf(): if ub_comm_wgrad.is_fp8_ubuf():
...@@ -257,7 +258,7 @@ class UserbuffersBackwardLinear(FusedOperation): ...@@ -257,7 +258,7 @@ class UserbuffersBackwardLinear(FusedOperation):
"Userbuffers reduce-scatter is not supported with FP8 buffers" "Userbuffers reduce-scatter is not supported with FP8 buffers"
) )
else: else:
ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad") ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad", with_quantized_compute)
ub_type_dgrad = CommOverlapType.RS ub_type_dgrad = CommOverlapType.RS
with_dgrad_reduce_scatter_dx = True with_dgrad_reduce_scatter_dx = True
if ub_comm_dgrad.is_fp8_ubuf(): if ub_comm_dgrad.is_fp8_ubuf():
...@@ -408,7 +409,7 @@ class UserbuffersBackwardLinear(FusedOperation): ...@@ -408,7 +409,7 @@ class UserbuffersBackwardLinear(FusedOperation):
# Get the communication stream from the dgrad GEMM to use for the AG # Get the communication stream from the dgrad GEMM to use for the AG
dgrad_send_stream, dgrad_recv_stream = ub_comm_dgrad.get_communication_stream() dgrad_send_stream, dgrad_recv_stream = ub_comm_dgrad.get_communication_stream()
ub_obj_overlap_wgrad = get_ub(ub_comm_name + "_wgrad") ub_obj_overlap_wgrad = get_ub(ub_comm_name + "_wgrad", with_quantized_compute)
grad_output_quantizer.set_usage(rowwise=False, columnwise=True) grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
...@@ -513,20 +514,22 @@ class UserbuffersBackwardLinear(FusedOperation): ...@@ -513,20 +514,22 @@ class UserbuffersBackwardLinear(FusedOperation):
# Saved tensors from forward pass # Saved tensors from forward pass
(x_local, w) = linear_op_ctx.saved_tensors (x_local, w) = linear_op_ctx.saved_tensors
# wgrad fusion # Megatron-LM wgrad fusion
# Note: Get grad tensor from param so we can accumulate
# directly into it.
accumulate_into_main_grad = linear_op._accumulate_into_main_grad accumulate_into_main_grad = linear_op._accumulate_into_main_grad
grad_weight = None grad_weight = None
if linear_op_ctx.weight_requires_grad and accumulate_into_main_grad: if linear_op_ctx.weight_requires_grad and accumulate_into_main_grad:
if hasattr(linear_op.weight, "__fsdp_param__"): weight_param = linear_op.weight
linear_op.weight.main_grad = linear_op.weight.get_main_grad() if hasattr(weight_param, "__fsdp_param__"):
weight_param.main_grad = weight_param.get_main_grad()
if not hasattr(linear_op.weight, "main_grad"): if not hasattr(weight_param, "main_grad"):
raise RuntimeError( raise RuntimeError(
"BasicLinear op is configured with " "BasicLinear op is configured with "
"accumulate_into_main_grad=True, " "accumulate_into_main_grad=True, "
"but weight parameter does not have main_grad attribute" "but weight parameter does not have main_grad attribute"
) )
grad_weight = linear_op.weight.main_grad.detach() grad_weight = weight_param.main_grad.detach()
else: else:
accumulate_into_main_grad = False accumulate_into_main_grad = False
...@@ -558,10 +561,21 @@ class UserbuffersBackwardLinear(FusedOperation): ...@@ -558,10 +561,21 @@ class UserbuffersBackwardLinear(FusedOperation):
# Clear input tensor if possible # Clear input tensor if possible
clear_tensor_data(x_local) clear_tensor_data(x_local)
# Return gradients # Megatron-LM wgrad fusion
grad_params = [() for _ in range(len(self.basic_ops))] # Note: Return dummy tensor for grad weight if needed.
if accumulate_into_main_grad: if accumulate_into_main_grad:
grad_weight = None grad_weight = None
weight_param = linear_op.weight
if hasattr(weight_param, "grad_added_to_main_grad"):
weight_param.grad_added_to_main_grad = True
grad_weight = get_dummy_wgrad(
list(weight_param.size()),
weight_param.dtype,
zero=getattr(weight_param, "zero_out_wgrad", False),
)
# Return gradients
grad_params = [() for _ in range(len(self.basic_ops))]
grad_params[self._op_idxs["linear"]] = (grad_weight,) grad_params[self._op_idxs["linear"]] = (grad_weight,)
if bias_op is not None: if bias_op is not None:
grad_params[self._op_idxs["bias"]] = (grad_bias,) grad_params[self._op_idxs["bias"]] = (grad_bias,)
......
...@@ -12,6 +12,7 @@ import torch ...@@ -12,6 +12,7 @@ import torch
from transformer_engine_torch import CommOverlapType from transformer_engine_torch import CommOverlapType
from ...cpp_extensions import general_gemm from ...cpp_extensions import general_gemm
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...distributed import get_distributed_world_size from ...distributed import get_distributed_world_size
from ...fp8 import FP8GlobalStateManager from ...fp8 import FP8GlobalStateManager
from ...module.base import ( from ...module.base import (
...@@ -189,7 +190,7 @@ class UserbuffersForwardLinear(FusedOperation): ...@@ -189,7 +190,7 @@ class UserbuffersForwardLinear(FusedOperation):
output_quantizer = None output_quantizer = None
# Get Userbuffers communicator # Get Userbuffers communicator
ub_comm = get_ub(ub_comm_name + "_fprop") ub_comm = get_ub(ub_comm_name + "_fprop", with_quantized_compute)
with_ub_all_gather = tensor_parallel_mode == "column" with_ub_all_gather = tensor_parallel_mode == "column"
with_ub_reduce_scatter = tensor_parallel_mode == "row" with_ub_reduce_scatter = tensor_parallel_mode == "row"
ub_type = CommOverlapType.AG if with_ub_all_gather else CommOverlapType.RS ub_type = CommOverlapType.AG if with_ub_all_gather else CommOverlapType.RS
...@@ -353,6 +354,8 @@ class UserbuffersForwardLinear(FusedOperation): ...@@ -353,6 +354,8 @@ class UserbuffersForwardLinear(FusedOperation):
# Save state for backward pass # Save state for backward pass
if linear_op_ctx.requires_grad: if linear_op_ctx.requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(x_local)
linear_op_ctx.save_for_backward(x_local, w) linear_op_ctx.save_for_backward(x_local, w)
linear_op_ctx.with_quantized_compute = with_quantized_compute linear_op_ctx.with_quantized_compute = with_quantized_compute
linear_op_ctx.input_quantizer = input_quantizer linear_op_ctx.input_quantizer = input_quantizer
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment