Unverified Commit 2894e493 authored by Pingtian Li's avatar Pingtian Li Committed by GitHub
Browse files

[Pytorch] Add get_backward_dw_params api for TE module (#2614)



* add grad reduce api for cuda graph hook
Signed-off-by: Pingtian Li <pingtianl@nvidia.com>

* fix code consistency
Signed-off-by: default avatarPingtian Li <pingtianl@nvidia.com>

---------
Signed-off-by: default avatarPingtian Li <pingtianl@nvidia.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
parent b8412430
......@@ -853,12 +853,22 @@ def _make_graphed_callables(
return functionalized
def make_graphed_attribute_functions(graph_idx):
# Get te modules for current graph
te_modules = visited_te_modules.get(graph_idx, set())
# Attach backward_dw as an attribute to the graphed callable.
def backward_dw():
if need_bwd_dw_graph.get(graph_idx, False):
bwd_dw_graphs[graph_idx].replay()
# Trigger the grad accumulation hook for wgrad graphs.
for module in te_modules:
if (
isinstance(module, TransformerEngineBaseModule)
and module.need_backward_dw()
):
module._trigger_wgrad_accumulation_and_reduce_hooks()
# Attach reset as an attribute to the graphed callable.
def reset():
fwd_graphs[graph_idx].reset()
......
......@@ -1526,6 +1526,12 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
bias_tensor.grad = bgrad.to(bias_tensor.dtype)
del wgrad
del bgrad
self._trigger_wgrad_accumulation_and_reduce_hooks()
def _trigger_wgrad_accumulation_and_reduce_hooks(self):
"""
Trigger the wgrad accumulation and reduce hooks.
"""
for wgrad_accumulation_and_reduce_hook in self.wgrad_accumulation_and_reduce_hooks:
wgrad_accumulation_and_reduce_hook()
......
......@@ -873,8 +873,7 @@ class GroupedLinear(TransformerEngineBaseModule):
del grad_biases_
del wgrad_list
del tensor_list
for wgrad_accumulation_and_reduce_hook in self.wgrad_accumulation_and_reduce_hooks:
wgrad_accumulation_and_reduce_hook()
self._trigger_wgrad_accumulation_and_reduce_hooks()
def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None:
"""Customize quantizers based on current scaling recipe + linear."""
......
......@@ -2506,5 +2506,4 @@ class LayerNormMLP(TransformerEngineBaseModule):
del fc2_wgrad
del fc1_wgrad
del fc1_bias_grad
for wgrad_accumulation_and_reduce_hook in self.wgrad_accumulation_and_reduce_hooks:
wgrad_accumulation_and_reduce_hook()
self._trigger_wgrad_accumulation_and_reduce_hooks()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment