Unverified commit f69e45be, authored by Kirthi Shankar Sivamani and committed by GitHub

[PyTorch] Don't use autograd hook for bwd reduction (#781)



Don't use autograd hook for bwd reduction
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 5d34b2ac
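In short, this change stops using torch.autograd.graph.register_multi_grad_hook over per-module dummy tensors (which does not fire for graphed modules) to trigger the backward amax reduction. Instead, each FP8 module records in forward whether it is the first FP8 module of the iteration; since autograd replays modules in reverse order, that module's backward runs last and calls FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False), skipped while a graph is being captured. Below is a minimal sketch of the new trigger pattern, under simplifying assumptions: the names _ToyLinear, reduce_and_update, and _modules_seen_in_fwd are illustrative stand-ins rather than Transformer Engine API, and the first-module and capture checks are simplified versions of FP8GlobalStateManager.is_first_fp8_module() and is_graph_capturing().

import torch

# Illustrative stand-ins, not Transformer Engine API.
_modules_seen_in_fwd = 0  # crude per-iteration counter; TE tracks this inside FP8GlobalStateManager

def reduce_and_update(forward: bool) -> None:
    # Stand-in for FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False),
    # which reduces amax histories across ranks and updates FP8 scaling factors.
    print(f"bwd amax reduction triggered (forward={forward})")

class _ToyLinear(torch.autograd.Function):
    """Toy stand-in for _Linear / _LayerNormLinear / _LayerNormMLP."""

    @staticmethod
    def forward(ctx, inp, weight):
        global _modules_seen_in_fwd
        # Remember whether this is the first FP8 module of the iteration; its
        # backward runs last, so it is the right place to do the reduction.
        ctx.is_first_module = _modules_seen_in_fwd == 0
        _modules_seen_in_fwd += 1
        ctx.save_for_backward(inp, weight)
        return inp @ weight.t()

    @staticmethod
    def backward(ctx, grad_out):
        inp, weight = ctx.saved_tensors
        dgrad = grad_out @ weight
        wgrad = grad_out.t() @ inp
        # Replaces the old multi-grad hook: trigger the reduction explicitly,
        # but skip it while a CUDA graph is being captured.
        capturing = torch.cuda.is_available() and torch.cuda.is_current_stream_capturing()
        if ctx.is_first_module and not capturing:
            reduce_and_update(forward=False)
        return dgrad, wgrad

# Usage: two chained "modules"; the reduction fires exactly once, from the
# backward of the module that ran first in forward.
x = torch.randn(4, 8, requires_grad=True)
w1 = torch.randn(16, 8, requires_grad=True)
w2 = torch.randn(8, 16, requires_grad=True)
_ToyLinear.apply(_ToyLinear.apply(x, w1), w2).sum().backward()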
@@ -81,8 +81,6 @@ class FP8GlobalStateManager:
fp8_tensors_recompute_buffer = []
fp8_available = None
reason_for_no_fp8 = ""
multi_grad_hook_tensors = []
bwd_amax_update_hook_registered = False
autocast_arguments = {}
autocast_to_fp8_params = {}
fp8_param_to_autocast = {}
@@ -106,8 +104,6 @@ class FP8GlobalStateManager:
cls.fp8_tensors_recompute_buffer = []
cls.fp8_available = None
cls.reason_for_no_fp8 = ""
cls.multi_grad_hook_tensors = []
cls.bwd_amax_update_hook_registered = False
cls.autocast_arguments = {}
cls.autocast_to_fp8_params = {}
cls.fp8_param_to_autocast = {}
@@ -370,16 +366,6 @@ class FP8GlobalStateManager:
_amax_and_scale_update(
amax_history, scale, scale_inv, get_fp8_max(recipe, forward), recipe)
@classmethod
def add_tensor_for_bwd_reduction_multi_grad_hook(cls, tensor):
"""Add tensor to list for multi grad hook."""
cls.multi_grad_hook_tensors.append(tensor)
@classmethod
def hook_for_bwd_amax_reduction(cls, grads: Tuple[torch.Tensor]) -> None: # pylint: disable=unused-argument
"""Executes at the end of backward pass."""
cls.reduce_and_update_fp8_tensors(forward=False)
@classmethod
def get_unique_autocast_key(
cls,
@@ -407,13 +393,6 @@ class FP8GlobalStateManager:
autocast_key = cls.get_unique_autocast_key(fp8_recipe, fp8_group)
cls.autocast_arguments[autocast_key] = (fp8_recipe, fp8_group)
if enabled and cls.FP8_AUTOCAST_DEPTH == 0 and not _graph and torch.is_grad_enabled():
if not cls.bwd_amax_update_hook_registered and len(cls.multi_grad_hook_tensors) > 0:
# This hook does not fire for graphed modules.
torch.autograd.graph.register_multi_grad_hook(
tuple(cls.multi_grad_hook_tensors), cls.hook_for_bwd_amax_reduction)
cls.bwd_amax_update_hook_registered = True
cls.FP8_ENABLED = enabled
cls.FP8_CALIBRATION = calibrating
cls.FP8_RECIPE = fp8_recipe
@@ -40,6 +40,7 @@ from ..distributed import (
)
from ..constants import GemmParallelModes, dist_group_type, TE_DType
from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing
from ._common import _apply_normalization, _noop_cat
from ..float8_tensor import Float8Tensor
@@ -89,7 +90,6 @@ class _LayerNormLinear(torch.autograd.Function):
ub_overlap_rs_dgrad: bool,
ub_overlap_ag: bool,
ub_name: str,
dummy_tensor: torch.Tensor, # pylint: disable=unused-argument
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# Make sure input dimensions are compatible
in_features = ln_weight.numel()
@@ -328,6 +328,7 @@ class _LayerNormLinear(torch.autograd.Function):
ctx.requires_dgrad = inp.requires_grad
ctx.normalization = normalization
ctx.primary_weights_in_fp8 = primary_weights_in_fp8
ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module()
# Row Parallel Linear
if parallel_mode == "row" and sequence_parallel:
@@ -660,6 +661,9 @@ class _LayerNormLinear(torch.autograd.Function):
else:
wgrad = None
if ctx.is_first_module and not is_graph_capturing():
FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
return (
dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None,
dgamma,
@@ -696,7 +700,6 @@ class _LayerNormLinear(torch.autograd.Function):
None,
None,
None,
None,
)
@@ -1001,10 +1004,6 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0"))
# Initialize a dummy tensor to be used as gradient hook for bwd amax reduction.
self.dummy_tensor = torch.zeros(1, device=device, requires_grad=True)
FP8GlobalStateManager.add_tensor_for_bwd_reduction_multi_grad_hook(self.dummy_tensor)
def reset_layer_norm_parameters(self) -> None:
"""Init LN params"""
warnings.warn(
@@ -1176,7 +1175,6 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.ub_overlap_rs_dgrad,
self.ub_overlap_ag,
self.ub_name,
self.dummy_tensor,
)
out = fwd_fn(*args)
@@ -49,7 +49,7 @@ from .. import cpp_extensions as tex
from ..constants import dist_group_type, TE_DType
from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing
from ..float8_tensor import Float8Tensor
from ._common import _apply_normalization
@@ -121,7 +120,6 @@ class _LayerNormMLP(torch.autograd.Function):
ub_overlap_rs: bool,
ub_overlap_ag: bool,
gemm_gelu_fusion: bool,
dummy_tensor: torch.Tensor, # pylint: disable=unused-argument,
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# Make sure input dimensions are compatible
in_features = ln_weight.numel()
@@ -545,6 +544,7 @@ class _LayerNormMLP(torch.autograd.Function):
ctx.requires_dgrad = inp.requires_grad
ctx.normalization = normalization
ctx.primary_weights_in_fp8 = primary_weights_in_fp8
ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module()
# Row Parallel Linear
if ub_overlap_rs:
@@ -1121,6 +1121,9 @@ class _LayerNormMLP(torch.autograd.Function):
else:
fc2_wgrad = None
if ctx.is_first_module and not is_graph_capturing():
FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
return (
dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None,
dgamma,
@@ -1165,7 +1168,6 @@ class _LayerNormMLP(torch.autograd.Function):
None,
None,
None,
None,
)
@@ -1429,10 +1431,6 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0"))
# Initialize a dummy tensor to be used as gradient hook for bwd amax reduction.
self.dummy_tensor = torch.zeros(1, device=device, requires_grad=True)
FP8GlobalStateManager.add_tensor_for_bwd_reduction_multi_grad_hook(self.dummy_tensor)
def reset_layer_norm_parameters(self) -> None:
"""Init LN params"""
warnings.warn(
@@ -1588,7 +1586,6 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.ub_overlap_rs,
self.ub_overlap_ag,
self.gemm_gelu_fusion,
self.dummy_tensor,
)
out = fwd_fn(*args)
@@ -43,7 +43,7 @@ from ..cpp_extensions import (
)
from ..constants import GemmParallelModes, dist_group_type
from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing
from ..float8_tensor import Float8Tensor
__all__ = ["Linear"]
@@ -81,7 +80,6 @@ class _Linear(torch.autograd.Function):
ub_overlap_rs: bool,
ub_overlap_ag: bool,
ub_name: str,
dummy_tensor: torch.Tensor, # pylint: disable=unused-argument
) -> torch.Tensor:
# Make sure input dimensions are compatible
in_features = weight.shape[-1]
@@ -321,6 +320,7 @@ class _Linear(torch.autograd.Function):
ctx.tp_size = tp_size
ctx.requires_dgrad = inp.requires_grad
ctx.primary_weights_in_fp8 = primary_weights_in_fp8
ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module()
# Row Parallel Linear
if ub_overlap_rs:
@@ -530,6 +530,9 @@ class _Linear(torch.autograd.Function):
else:
wgrad = None
if ctx.is_first_module and not is_graph_capturing():
FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
return (
wgrad,
None,
@@ -555,7 +558,6 @@ class _Linear(torch.autograd.Function):
None,
None,
None,
None,
)
@@ -798,10 +800,6 @@ class Linear(TransformerEngineBaseModule):
else:
self.gemm_bias_unfused_add = False
# Initialize a dummy tensor to be used as gradient hook for bwd amax reduction.
self.dummy_tensor = torch.zeros(1, device=device, requires_grad=True)
FP8GlobalStateManager.add_tensor_for_bwd_reduction_multi_grad_hook(self.dummy_tensor)
def reset_parameters(self, defer_init=False):
super().reset_parameters(defer_init=defer_init)
@@ -941,7 +939,6 @@ class Linear(TransformerEngineBaseModule):
self.ub_overlap_rs,
self.ub_overlap_ag,
self.ub_name,
self.dummy_tensor,
)
out = linear_fn(*args)