"vscode:/vscode.git/clone" did not exist on "aa06107cbc1cc7378c665809c1608c53070447ea"
Unverified Commit 6273cede authored by buptzyb, committed by GitHub

[PyTorch] Support delay_wgrad_compute cudagraph (#1948)



* support cudagraph dw
Signed-off-by: Robin Zhang <robinz@nvidia.com>

* fix lint
Signed-off-by: Robin Zhang <robinz@nvidia.com>

* fix ci
Signed-off-by: Robin Zhang <robinz@nvidia.com>

---------
Signed-off-by: Robin Zhang <robinz@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 021e1e62
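For orientation, here is a minimal usage sketch of the behavior this commit enables: a module using delayed weight-gradient computation is captured with make_graphed_callables, and the returned graphed callable exposes a backward_dw() hook that replays the captured wgrad graph. The delay_wgrad_compute constructor argument and the exact call shape are assumptions inferred from this diff, not a verified API reference.

import torch
import transformer_engine.pytorch as te

# Assumed API: te.Linear(..., delay_wgrad_compute=True) defers the weight
# gradient GEMM out of the main backward pass (inferred from this diff).
model = te.Linear(1024, 1024, delay_wgrad_compute=True).cuda()
sample_input = torch.randn(8, 1024, device="cuda", requires_grad=True)

graphed = te.make_graphed_callables(model, (sample_input,))

out = graphed(sample_input)   # replays the captured forward graph
out.sum().backward()          # replays the captured dgrad graph
graphed.backward_dw()         # new in this commit: replays the delayed-wgrad graph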
@@ -322,14 +322,16 @@ def _make_graphed_callables(
fwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(flatten_sample_args))]
bwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(flatten_sample_args))]
bwd_dw_graphs = [torch.cuda.CUDAGraph() for _ in range(len(flatten_sample_args))]
graph_callables = [None for _ in range(len(flatten_sample_args))]
# For cases with multiple active RNG states, e.g. TP.
if graph_safe_rng_available():
for _, state in get_all_rng_states().items():
for fwd_graph, bwd_graph in zip(fwd_graphs, bwd_graphs):
for fwd_graph, bwd_graph, bwd_dw_graph in zip(fwd_graphs, bwd_graphs, bwd_dw_graphs):
fwd_graph.register_generator_state(state)
bwd_graph.register_generator_state(state)
bwd_dw_graph.register_generator_state(state)
mempool = graph_pool_handle() if pool is None else pool
@@ -366,21 +368,8 @@ def _make_graphed_callables(
), f"Warmup runs {len(warmup_func)} but only {len(set(warmup_func_idx))} are unique."
# Filter the TE modules that cudagraph can access.
visited_te_modules = set()
def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument
if isinstance(module, TransformerEngineBaseModule):
visited_te_modules.add(module)
# If forward is called on a BasicOperation directly the hook will run
elif isinstance(module, BasicOperation):
visited_te_modules.add(module)
# If forward is called on a te.ops.Sequential it is not called on its constituent ops
elif isinstance(module, Sequential):
assert module._module_groups is not None, "Should have been initialized by warmup"
for module_group in module._module_groups:
if isinstance(module_group, OperationFuser):
for basic_op in module_group._basic_ops:
visited_te_modules.add(basic_op)
visited_te_modules = {}
need_bwd_dw_graph = {}
# Run warmup and do the above filtering.
with torch.cuda.stream(torch.cuda.Stream()):
@@ -388,6 +377,31 @@ def _make_graphed_callables(
args = sample_args[func_idx]
kwargs = sample_kwargs[func_idx]
static_input_surface = per_callable_static_input_surfaces[func_idx]
def hook_fn(
module, inputs, outputs, func_idx=func_idx
): # pylint: disable=unused-argument
modules = set()
if isinstance(module, TransformerEngineBaseModule):
modules.add(module)
# If forward is called on a BasicOperation directly the hook will run
elif isinstance(module, BasicOperation):
modules.add(module)
# If forward is called on a te.ops.Sequential it is not called on its constituent ops
elif isinstance(module, Sequential):
assert (
module._module_groups is not None
), "Should have been initialized by warmup"
for module_group in module._module_groups:
if isinstance(module_group, OperationFuser):
for basic_op in module_group._basic_ops:
modules.add(basic_op)
if modules:
if func_idx not in visited_te_modules:
visited_te_modules[func_idx] = modules
else:
visited_te_modules[func_idx].update(modules)
for warmup_iter in range(num_warmup_iters):
hooks = []
for module in func.modules():
@@ -432,6 +446,15 @@ def _make_graphed_callables(
module_params_with_grad
)
per_callable_static_input_surfaces[func_idx] = static_input_surface
# Run wgrad. This is essential for some TE modules when they have
# delay_wgrad_compute enabled.
need_backward_dw = False
for module in visited_te_modules.get(func_idx, set()):
if hasattr(module, "need_backward_dw") and module.need_backward_dw():
need_backward_dw = True
module.backward_dw()
need_bwd_dw_graph[func_idx] = need_backward_dw
else:
grad_inputs = None
del outputs, grad_inputs
@@ -514,6 +537,17 @@ def _make_graphed_callables(
allow_unused=allow_unused_input,
retain_graph=retain_graph_in_backward,
)
# If no module needs backward_dw, the bwd_dw_graph would be empty,
# so skip capturing it.
if need_bwd_dw_graph[per_callable_bwd_idx]:
bwd_dw_graph = bwd_dw_graphs[per_callable_bwd_idx]
with _graph_context_wrapper(bwd_dw_graph, pool=mempool):
for module in visited_te_modules[per_callable_bwd_idx]:
if (
hasattr(module, "need_backward_dw")
and module.need_backward_dw()
):
module.backward_dw()
# Constructs a tuple suitable for returning from Graphed.backward:
# Pads out the actually-needed grads with Nones in gradient slots for inputs
# that don't require grad. I couldn't think of a one-liner for this pattern.
@@ -582,10 +616,12 @@ def _make_graphed_callables(
# Capture backward graphs in reverse order
per_callable_static_grad_outputs = []
per_callable_static_grad_inputs = []
for static_input_surface, static_outputs, bwd_graph in zip(
for static_input_surface, static_outputs, bwd_graph, bwd_dw_graph, bwd_idx in zip(
reversed(per_callable_static_input_surfaces),
reversed(per_callable_static_outputs),
reversed(bwd_graphs),
reversed(bwd_dw_graphs),
reversed(range(len(per_callable_static_input_surfaces))),
):
# For now, assumes all static_outputs require grad
static_grad_outputs = tuple(
@@ -601,6 +637,11 @@ def _make_graphed_callables(
allow_unused=allow_unused_input,
retain_graph=retain_graph_in_backward,
)
if need_bwd_dw_graph[bwd_idx]:
with _graph_context_wrapper(bwd_dw_graph, pool=mempool):
for module in visited_te_modules[bwd_idx]:
if hasattr(module, "need_backward_dw") and module.need_backward_dw():
module.backward_dw()
# Constructs a tuple suitable for returning from Graphed.backward:
# Pads out the actually-needed grads with Nones in gradient slots for inputs that
# don't require grad. I couldn't think of a slick one-liner for this pattern.
@@ -732,9 +773,10 @@ def _make_graphed_callables(
)
func = graph_callables[i]
te_modules = visited_te_modules.get(i, set())
if isinstance(func, torch.nn.Module):
def make_graphed_forward(func, graph_training_state, graphed, orig_fwd):
def make_graphed_forward(func, graph_training_state, graphed, orig_fwd, te_modules):
def new_fwd(*user_args, **user_kwargs):
# If the module's training-or-eval state matches what we graphed,
# run the graph, otherwise run the original forward method
@@ -743,7 +785,7 @@ def _make_graphed_callables(
if FP8GlobalStateManager.is_fp8_enabled():
fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
for m in func.modules():
if m not in visited_te_modules:
if m not in te_modules:
# Only set the FP8 meta for the modules included in the forward pass
continue
if isinstance(m, TransformerEngineBaseModule):
@@ -780,7 +822,7 @@ def _make_graphed_callables(
return new_fwd
forward = make_graphed_forward(func, func.training, graphed, func.forward)
forward = make_graphed_forward(func, func.training, graphed, func.forward, te_modules)
if _order is None:
func.forward = forward
ret.append(func)
@@ -789,6 +831,16 @@ def _make_graphed_callables(
else:
ret.append(graphed)
# Attach backward_dw as an attribute to the graphed callable.
def backward_dw(
need_backward_dw=need_bwd_dw_graph.get(i, False),
bwd_dw_graph=bwd_dw_graphs[i],
):
if need_backward_dw:
bwd_dw_graph.replay()
setattr(ret[-1], "backward_dw", backward_dw)
if just_one_callable:
return ret[0]
......
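A note on the closure above: the attached backward_dw binds its per-callable state (need_bwd_dw_graph.get(i, False) and bwd_dw_graphs[i]) through default arguments so each graphed callable replays its own graph rather than the last one created in the loop. A standalone sketch of that late-binding pattern, with invented names and assuming the graphs were already captured on a CUDA device:

import torch

dw_graphs = [torch.cuda.CUDAGraph() for _ in range(3)]   # assume already captured
needs_dw = {0: True, 1: False, 2: True}

graphed_callables = []
for i in range(3):
    # Default arguments freeze the current loop state; closing over `i`
    # directly would make every callable see the final value of `i`.
    def backward_dw(need=needs_dw.get(i, False), graph=dw_graphs[i]):
        if need:
            graph.replay()

    graphed_callables.append(backward_dw)

graphed_callables[1]()  # no-op: callable 1 has no delayed wgrad to replay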
@@ -662,6 +662,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
self._fp8_workspaces: Dict[str, QuantizedTensor] = {}
self.activation_dtype: Optional[torch.dtype] = None
self.wgrad_accumulation_and_reduce_hooks = []
self.wgrad_store = None
if not TEDebugState.debug_enabled:
TEDebugState.initialize()
@@ -1481,12 +1482,21 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
"""
self.wgrad_accumulation_and_reduce_hooks.append(wgrad_accumulation_and_reduce_hook)
def need_backward_dw(self):
"""
Check if this module needs to execute the delayed weight gradient computation.
This method should be called at the beginning of backward_dw() to determine whether the
delayed computation should actually run or the call should return immediately.
Users can also call it manually before invoking backward_dw().
"""
return self.wgrad_store is not None and self.wgrad_store.delay_wgrad_compute()
def backward_dw(self):
"""
Execute the delayed weight gradient computation.
This method is called after the main backward pass to compute weight gradients.
"""
if self.wgrad_store is None or not self.wgrad_store.delay_wgrad_compute():
if not self.need_backward_dw():
return
with torch.cuda.nvtx.range(f"_{self.__class__.__name__}_wgrad"):
(wgrad, bgrad), _ = self.wgrad_store.pop()
......
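For context on what backward_dw() pops, here is a toy sketch of the delayed-wgrad pattern: the backward pass stashes the tensors needed for the weight gradient instead of computing it inline, and backward_dw() finishes the GEMM later. ToyWgradStore and its methods are invented for illustration; they only mimic the shape of the wgrad_store.pop() result used above, not TE's real WgradStore API.

import torch


class ToyWgradStore:
    """Illustrative stand-in for a delayed-wgrad buffer (not TE's real class)."""

    def __init__(self, delay: bool):
        self._delay = delay
        self._pending = []

    def delay_wgrad_compute(self) -> bool:
        return self._delay

    def put(self, grad_output: torch.Tensor, saved_input: torch.Tensor) -> None:
        # The backward pass stashes its operands instead of doing the wgrad GEMM.
        self._pending.append((grad_output, saved_input))

    def pop(self):
        grad_output, saved_input = self._pending.pop(0)
        wgrad = grad_output.t() @ saved_input   # dW = dY^T @ X
        bgrad = grad_output.sum(dim=0)          # dB = sum of dY over the batch
        return (wgrad, bgrad), [grad_output, saved_input]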
@@ -840,7 +840,7 @@ class GroupedLinear(TransformerEngineBaseModule):
Execute the delayed weight gradient computation.
This method is called after the main backward pass to compute weight gradients.
"""
if self.wgrad_store is None or not self.wgrad_store.delay_wgrad_compute():
if not self.need_backward_dw():
return
with torch.cuda.nvtx.range("_GroupedLinear_wgrad"):
(_, grad_biases_, _), tensor_list = self.wgrad_store.pop()
......
@@ -2211,7 +2211,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
Execute the delayed weight gradient computation.
This method is called after the main backward pass to compute weight gradients.
"""
if self.wgrad_store is None or not self.wgrad_store.delay_wgrad_compute():
if not self.need_backward_dw():
return
with torch.cuda.nvtx.range("_LayerNormMLP_wgrad"):
(fc2_wgrad, fc2_bias_grad_, *_), tensor_list_fc2 = self.wgrad_store.pop()
......
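As the new need_backward_dw() docstring notes, callers can also guard the call themselves outside of graph capture. A minimal sketch, where te_modules is a hypothetical iterable of TE module instances:

# Manually flush delayed weight gradients after the main backward pass.
for module in te_modules:
    if module.need_backward_dw():
        module.backward_dw()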