Unverified Commit f69e45be authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

[PyTorch] Don't use autograd hook for bwd reduction (#781)



Don't use autograd hook for bwd reduction
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 5d34b2ac
...@@ -81,8 +81,6 @@ class FP8GlobalStateManager: ...@@ -81,8 +81,6 @@ class FP8GlobalStateManager:
fp8_tensors_recompute_buffer = [] fp8_tensors_recompute_buffer = []
fp8_available = None fp8_available = None
reason_for_no_fp8 = "" reason_for_no_fp8 = ""
multi_grad_hook_tensors = []
bwd_amax_update_hook_registered = False
autocast_arguments = {} autocast_arguments = {}
autocast_to_fp8_params = {} autocast_to_fp8_params = {}
fp8_param_to_autocast = {} fp8_param_to_autocast = {}
...@@ -106,8 +104,6 @@ class FP8GlobalStateManager: ...@@ -106,8 +104,6 @@ class FP8GlobalStateManager:
cls.fp8_tensors_recompute_buffer = [] cls.fp8_tensors_recompute_buffer = []
cls.fp8_available = None cls.fp8_available = None
cls.reason_for_no_fp8 = "" cls.reason_for_no_fp8 = ""
cls.multi_grad_hook_tensors = []
cls.bwd_amax_update_hook_registered = False
cls.autocast_arguments = {} cls.autocast_arguments = {}
cls.autocast_to_fp8_params = {} cls.autocast_to_fp8_params = {}
cls.fp8_param_to_autocast = {} cls.fp8_param_to_autocast = {}
...@@ -370,16 +366,6 @@ class FP8GlobalStateManager: ...@@ -370,16 +366,6 @@ class FP8GlobalStateManager:
_amax_and_scale_update( _amax_and_scale_update(
amax_history, scale, scale_inv, get_fp8_max(recipe, forward), recipe) amax_history, scale, scale_inv, get_fp8_max(recipe, forward), recipe)
@classmethod
def add_tensor_for_bwd_reduction_multi_grad_hook(cls, tensor):
"""Add tensor to list for multi grad hook."""
cls.multi_grad_hook_tensors.append(tensor)
@classmethod
def hook_for_bwd_amax_reduction(cls, grads: Tuple[torch.Tensor]) -> None: # pylint: disable=unused-argument
"""Executes at the end of backward pass."""
cls.reduce_and_update_fp8_tensors(forward=False)
@classmethod @classmethod
def get_unique_autocast_key( def get_unique_autocast_key(
cls, cls,
...@@ -407,13 +393,6 @@ class FP8GlobalStateManager: ...@@ -407,13 +393,6 @@ class FP8GlobalStateManager:
autocast_key = cls.get_unique_autocast_key(fp8_recipe, fp8_group) autocast_key = cls.get_unique_autocast_key(fp8_recipe, fp8_group)
cls.autocast_arguments[autocast_key] = (fp8_recipe, fp8_group) cls.autocast_arguments[autocast_key] = (fp8_recipe, fp8_group)
if enabled and cls.FP8_AUTOCAST_DEPTH == 0 and not _graph and torch.is_grad_enabled():
if not cls.bwd_amax_update_hook_registered and len(cls.multi_grad_hook_tensors) > 0:
# This hook does not fire for graphed modules.
torch.autograd.graph.register_multi_grad_hook(
tuple(cls.multi_grad_hook_tensors), cls.hook_for_bwd_amax_reduction)
cls.bwd_amax_update_hook_registered = True
cls.FP8_ENABLED = enabled cls.FP8_ENABLED = enabled
cls.FP8_CALIBRATION = calibrating cls.FP8_CALIBRATION = calibrating
cls.FP8_RECIPE = fp8_recipe cls.FP8_RECIPE = fp8_recipe
......
...@@ -40,6 +40,7 @@ from ..distributed import ( ...@@ -40,6 +40,7 @@ from ..distributed import (
) )
from ..constants import GemmParallelModes, dist_group_type, TE_DType from ..constants import GemmParallelModes, dist_group_type, TE_DType
from ..jit import no_torch_dynamo from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing
from ._common import _apply_normalization, _noop_cat from ._common import _apply_normalization, _noop_cat
from ..float8_tensor import Float8Tensor from ..float8_tensor import Float8Tensor
...@@ -89,7 +90,6 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -89,7 +90,6 @@ class _LayerNormLinear(torch.autograd.Function):
ub_overlap_rs_dgrad: bool, ub_overlap_rs_dgrad: bool,
ub_overlap_ag: bool, ub_overlap_ag: bool,
ub_name: str, ub_name: str,
dummy_tensor: torch.Tensor, # pylint: disable=unused-argument
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# Make sure input dimensions are compatible # Make sure input dimensions are compatible
in_features = ln_weight.numel() in_features = ln_weight.numel()
...@@ -328,6 +328,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -328,6 +328,7 @@ class _LayerNormLinear(torch.autograd.Function):
ctx.requires_dgrad = inp.requires_grad ctx.requires_dgrad = inp.requires_grad
ctx.normalization = normalization ctx.normalization = normalization
ctx.primary_weights_in_fp8 = primary_weights_in_fp8 ctx.primary_weights_in_fp8 = primary_weights_in_fp8
ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module()
# Row Parallel Linear # Row Parallel Linear
if parallel_mode == "row" and sequence_parallel: if parallel_mode == "row" and sequence_parallel:
...@@ -660,6 +661,9 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -660,6 +661,9 @@ class _LayerNormLinear(torch.autograd.Function):
else: else:
wgrad = None wgrad = None
if ctx.is_first_module and not is_graph_capturing():
FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
return ( return (
dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None, dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None,
dgamma, dgamma,
...@@ -696,7 +700,6 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -696,7 +700,6 @@ class _LayerNormLinear(torch.autograd.Function):
None, None,
None, None,
None, None,
None,
) )
...@@ -1001,10 +1004,6 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1001,10 +1004,6 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0")) self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0"))
# Initialize a dummy tensor to be used as gradient hook for bwd amax reduction.
self.dummy_tensor = torch.zeros(1, device=device, requires_grad=True)
FP8GlobalStateManager.add_tensor_for_bwd_reduction_multi_grad_hook(self.dummy_tensor)
def reset_layer_norm_parameters(self) -> None: def reset_layer_norm_parameters(self) -> None:
"""Init LN params""" """Init LN params"""
warnings.warn( warnings.warn(
...@@ -1176,7 +1175,6 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1176,7 +1175,6 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.ub_overlap_rs_dgrad, self.ub_overlap_rs_dgrad,
self.ub_overlap_ag, self.ub_overlap_ag,
self.ub_name, self.ub_name,
self.dummy_tensor,
) )
out = fwd_fn(*args) out = fwd_fn(*args)
......
...@@ -49,7 +49,7 @@ from .. import cpp_extensions as tex ...@@ -49,7 +49,7 @@ from .. import cpp_extensions as tex
from ..constants import dist_group_type, TE_DType from ..constants import dist_group_type, TE_DType
from ..jit import no_torch_dynamo from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing
from ..float8_tensor import Float8Tensor from ..float8_tensor import Float8Tensor
from ._common import _apply_normalization from ._common import _apply_normalization
...@@ -121,7 +121,6 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -121,7 +121,6 @@ class _LayerNormMLP(torch.autograd.Function):
ub_overlap_rs: bool, ub_overlap_rs: bool,
ub_overlap_ag: bool, ub_overlap_ag: bool,
gemm_gelu_fusion: bool, gemm_gelu_fusion: bool,
dummy_tensor: torch.Tensor, # pylint: disable=unused-argument,
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# Make sure input dimensions are compatible # Make sure input dimensions are compatible
in_features = ln_weight.numel() in_features = ln_weight.numel()
...@@ -545,6 +544,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -545,6 +544,7 @@ class _LayerNormMLP(torch.autograd.Function):
ctx.requires_dgrad = inp.requires_grad ctx.requires_dgrad = inp.requires_grad
ctx.normalization = normalization ctx.normalization = normalization
ctx.primary_weights_in_fp8 = primary_weights_in_fp8 ctx.primary_weights_in_fp8 = primary_weights_in_fp8
ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module()
# Row Parallel Linear # Row Parallel Linear
if ub_overlap_rs: if ub_overlap_rs:
...@@ -1121,6 +1121,9 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -1121,6 +1121,9 @@ class _LayerNormMLP(torch.autograd.Function):
else: else:
fc2_wgrad = None fc2_wgrad = None
if ctx.is_first_module and not is_graph_capturing():
FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
return ( return (
dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None, dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None,
dgamma, dgamma,
...@@ -1165,7 +1168,6 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -1165,7 +1168,6 @@ class _LayerNormMLP(torch.autograd.Function):
None, None,
None, None,
None, None,
None,
) )
...@@ -1429,10 +1431,6 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1429,10 +1431,6 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0")) self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0"))
# Initialize a dummy tensor to be used as gradient hook for bwd amax reduction.
self.dummy_tensor = torch.zeros(1, device=device, requires_grad=True)
FP8GlobalStateManager.add_tensor_for_bwd_reduction_multi_grad_hook(self.dummy_tensor)
def reset_layer_norm_parameters(self) -> None: def reset_layer_norm_parameters(self) -> None:
"""Init LN params""" """Init LN params"""
warnings.warn( warnings.warn(
...@@ -1588,7 +1586,6 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1588,7 +1586,6 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.ub_overlap_rs, self.ub_overlap_rs,
self.ub_overlap_ag, self.ub_overlap_ag,
self.gemm_gelu_fusion, self.gemm_gelu_fusion,
self.dummy_tensor,
) )
out = fwd_fn(*args) out = fwd_fn(*args)
......
...@@ -43,7 +43,7 @@ from ..cpp_extensions import ( ...@@ -43,7 +43,7 @@ from ..cpp_extensions import (
) )
from ..constants import GemmParallelModes, dist_group_type from ..constants import GemmParallelModes, dist_group_type
from ..jit import no_torch_dynamo from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing
from ..float8_tensor import Float8Tensor from ..float8_tensor import Float8Tensor
__all__ = ["Linear"] __all__ = ["Linear"]
...@@ -81,7 +81,6 @@ class _Linear(torch.autograd.Function): ...@@ -81,7 +81,6 @@ class _Linear(torch.autograd.Function):
ub_overlap_rs: bool, ub_overlap_rs: bool,
ub_overlap_ag: bool, ub_overlap_ag: bool,
ub_name: str, ub_name: str,
dummy_tensor: torch.Tensor, # pylint: disable=unused-argument
) -> torch.Tensor: ) -> torch.Tensor:
# Make sure input dimensions are compatible # Make sure input dimensions are compatible
in_features = weight.shape[-1] in_features = weight.shape[-1]
...@@ -321,6 +320,7 @@ class _Linear(torch.autograd.Function): ...@@ -321,6 +320,7 @@ class _Linear(torch.autograd.Function):
ctx.tp_size = tp_size ctx.tp_size = tp_size
ctx.requires_dgrad = inp.requires_grad ctx.requires_dgrad = inp.requires_grad
ctx.primary_weights_in_fp8 = primary_weights_in_fp8 ctx.primary_weights_in_fp8 = primary_weights_in_fp8
ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module()
# Row Parallel Linear # Row Parallel Linear
if ub_overlap_rs: if ub_overlap_rs:
...@@ -530,6 +530,9 @@ class _Linear(torch.autograd.Function): ...@@ -530,6 +530,9 @@ class _Linear(torch.autograd.Function):
else: else:
wgrad = None wgrad = None
if ctx.is_first_module and not is_graph_capturing():
FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
return ( return (
wgrad, wgrad,
None, None,
...@@ -555,7 +558,6 @@ class _Linear(torch.autograd.Function): ...@@ -555,7 +558,6 @@ class _Linear(torch.autograd.Function):
None, None,
None, None,
None, None,
None,
) )
...@@ -798,10 +800,6 @@ class Linear(TransformerEngineBaseModule): ...@@ -798,10 +800,6 @@ class Linear(TransformerEngineBaseModule):
else: else:
self.gemm_bias_unfused_add = False self.gemm_bias_unfused_add = False
# Initialize a dummy tensor to be used as gradient hook for bwd amax reduction.
self.dummy_tensor = torch.zeros(1, device=device, requires_grad=True)
FP8GlobalStateManager.add_tensor_for_bwd_reduction_multi_grad_hook(self.dummy_tensor)
def reset_parameters(self, defer_init=False): def reset_parameters(self, defer_init=False):
super().reset_parameters(defer_init=defer_init) super().reset_parameters(defer_init=defer_init)
...@@ -941,7 +939,6 @@ class Linear(TransformerEngineBaseModule): ...@@ -941,7 +939,6 @@ class Linear(TransformerEngineBaseModule):
self.ub_overlap_rs, self.ub_overlap_rs,
self.ub_overlap_ag, self.ub_overlap_ag,
self.ub_name, self.ub_name,
self.dummy_tensor,
) )
out = linear_fn(*args) out = linear_fn(*args)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment