Unverified commit 2894e493 authored by Pingtian Li, committed by GitHub

[Pytorch] Add get_backward_dw_params api for TE module (#2614)



* add grad reduce api for cuda graph hook
Signed-off-by: Pingtian Li <pingtianl@nvidia.com>

* fix code consistency
Signed-off-by: Pingtian Li <pingtianl@nvidia.com>

---------
Signed-off-by: Pingtian Li <pingtianl@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent b8412430
@@ -853,12 +853,22 @@ def _make_graphed_callables(
         return functionalized

     def make_graphed_attribute_functions(graph_idx):
+        # Get te modules for current graph
+        te_modules = visited_te_modules.get(graph_idx, set())
+
         # Attach backward_dw as an attribute to the graphed callable.
         def backward_dw():
             if need_bwd_dw_graph.get(graph_idx, False):
                 bwd_dw_graphs[graph_idx].replay()
+                # Trigger the grad accumulation hook for wgrad graphs.
+                for module in te_modules:
+                    if (
+                        isinstance(module, TransformerEngineBaseModule)
+                        and module.need_backward_dw()
+                    ):
+                        module._trigger_wgrad_accumulation_and_reduce_hooks()

         # Attach reset as an attribute to the graphed callable.
         def reset():
             fwd_graphs[graph_idx].reset()
...
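For context, a minimal usage sketch of the backward_dw attribute this hunk extends. It assumes a module built with the delay_wgrad_compute flag, so that backward() produces only data gradients and backward_dw() later replays the captured wgrad graph; with this patch the replay also fires the module's wgrad accumulation/reduce hooks. Module sizes and shapes are illustrative, not taken from the patch.

# Usage sketch, not part of the patch. Assumes delay_wgrad_compute is
# supported by the module, as in recent Transformer Engine releases.
import torch
import transformer_engine.pytorch as te

model = te.Linear(1024, 1024, delay_wgrad_compute=True).cuda()
sample_input = torch.randn(8, 1024, device="cuda", requires_grad=True)
graphed = te.make_graphed_callables(model, (sample_input,))

out = graphed(sample_input)
out.sum().backward()   # dgrad only; wgrad is deferred
graphed.backward_dw()  # replays the wgrad graph, then (with this patch)
                       # triggers each module's accumulation/reduce hooks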
@@ -1526,8 +1526,14 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
             bias_tensor.grad = bgrad.to(bias_tensor.dtype)
         del wgrad
         del bgrad
-        for wgrad_accumulation_and_reduce_hook in self.wgrad_accumulation_and_reduce_hooks:
-            wgrad_accumulation_and_reduce_hook()
+        self._trigger_wgrad_accumulation_and_reduce_hooks()
+
+    def _trigger_wgrad_accumulation_and_reduce_hooks(self):
+        """
+        Trigger the wgrad accumulation and reduce hooks.
+        """
+        for wgrad_accumulation_and_reduce_hook in self.wgrad_accumulation_and_reduce_hooks:
+            wgrad_accumulation_and_reduce_hook()

     def is_debug_iter(self) -> bool:
         """
...
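The helper factored out above is what external frameworks hook into. A hedged sketch of registering such a hook, continuing the example after the first hunk: the hook body (a data-parallel grad all-reduce) is hypothetical, but the list name matches the attribute the new helper iterates.

# Hypothetical hook; appending to wgrad_accumulation_and_reduce_hooks
# mirrors the attribute iterated by the new helper. A dedicated
# registration method, if the module provides one, would be preferable.
import torch.distributed as dist

def reduce_wgrad():
    # Average the weight gradient across the data-parallel group.
    dist.all_reduce(model.weight.grad, op=dist.ReduceOp.AVG)

model.wgrad_accumulation_and_reduce_hooks.append(reduce_wgrad)
# graphed.backward_dw() / model.backward_dw() will now invoke reduce_wgrad().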
@@ -873,8 +873,7 @@ class GroupedLinear(TransformerEngineBaseModule):
         del grad_biases_
         del wgrad_list
         del tensor_list
-        for wgrad_accumulation_and_reduce_hook in self.wgrad_accumulation_and_reduce_hooks:
-            wgrad_accumulation_and_reduce_hook()
+        self._trigger_wgrad_accumulation_and_reduce_hooks()

     def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None:
         """Customize quantizers based on current scaling recipe + linear."""
...
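The same refactor applies to GroupedLinear's eager (non-graphed) path. A sketch under the same delay_wgrad_compute assumption; the num_gemms and m_splits values are illustrative.

# Eager-mode sketch of the call path this hunk touches.
grouped = te.GroupedLinear(4, 1024, 1024, delay_wgrad_compute=True).cuda()
x = torch.randn(16, 1024, device="cuda", requires_grad=True)
y = grouped(x, m_splits=[4, 4, 4, 4])  # one split per GEMM
y.sum().backward()     # dgrad only
grouped.backward_dw()  # wgrad/bgrad, then the shared hook-trigger helper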
@@ -2506,5 +2506,4 @@ class LayerNormMLP(TransformerEngineBaseModule):
         del fc2_wgrad
         del fc1_wgrad
         del fc1_bias_grad
-        for wgrad_accumulation_and_reduce_hook in self.wgrad_accumulation_and_reduce_hooks:
-            wgrad_accumulation_and_reduce_hook()
+        self._trigger_wgrad_accumulation_and_reduce_hooks()