"...git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "00328ac79387c418f598a45d56a5a5d14abae153"
Unverified commit 6273cede authored by buptzyb, committed by GitHub

[PyTorch] Support delay_wgrad_compute cudagraph (#1948)



* support cudagraph dw
Signed-off-by: Robin Zhang <robinz@nvidia.com>

* fix lint
Signed-off-by: Robin Zhang <robinz@nvidia.com>

* fix ci
Signed-off-by: Robin Zhang <robinz@nvidia.com>

---------
Signed-off-by: Robin Zhang <robinz@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 021e1e62
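For context, here is a minimal usage sketch of what this change enables. It assumes the module under graph capture was built with delayed wgrad compute enabled; the delay_wgrad_compute constructor argument and the shapes are illustrative assumptions, not taken from this diff. The new part is that the callable returned by make_graphed_callables carries a backward_dw() attribute that replays the separately captured wgrad graph.

# Hedged usage sketch -- module/shape choices and the delay_wgrad_compute
# argument are assumptions for illustration, not taken from this PR.
import torch
import transformer_engine.pytorch as te

# A TE module whose weight-gradient GEMM is deferred until backward_dw().
linear = te.Linear(1024, 1024, delay_wgrad_compute=True).cuda()

sample_input = torch.randn(32, 1024, device="cuda", requires_grad=True)
graphed_linear = te.make_graphed_callables(linear, (sample_input,))

x = torch.randn(32, 1024, device="cuda", requires_grad=True)
out = graphed_linear(x)
out.sum().backward()   # replays the fwd/bwd graphs; the wgrad GEMM is still pending

# New in this change: the graphed callable exposes backward_dw(), which replays
# the separately captured wgrad graph (a no-op when nothing was delayed).
graphed_linear.backward_dw()

If none of the captured modules delayed their wgrad, backward_dw() is a no-op, matching the need_bwd_dw_graph bookkeeping introduced in the diff below.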
@@ -322,14 +322,16 @@ def _make_graphed_callables(
     fwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(flatten_sample_args))]
     bwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(flatten_sample_args))]
+    bwd_dw_graphs = [torch.cuda.CUDAGraph() for _ in range(len(flatten_sample_args))]
     graph_callables = [None for _ in range(len(flatten_sample_args))]

     # For cases with multiple active RNG states, e.g. TP.
     if graph_safe_rng_available():
         for _, state in get_all_rng_states().items():
-            for fwd_graph, bwd_graph in zip(fwd_graphs, bwd_graphs):
+            for fwd_graph, bwd_graph, bwd_dw_graph in zip(fwd_graphs, bwd_graphs, bwd_dw_graphs):
                 fwd_graph.register_generator_state(state)
                 bwd_graph.register_generator_state(state)
+                bwd_dw_graph.register_generator_state(state)

     mempool = graph_pool_handle() if pool is None else pool
@@ -366,21 +368,8 @@ def _make_graphed_callables(
     ), f"Warmup runs {len(warmup_func)} but only {len(set(warmup_func_idx))} are unique."

     # Filter the TE modules that cudagraph can access.
-    visited_te_modules = set()
-
-    def hook_fn(module, inputs, outputs):  # pylint: disable=unused-argument
-        if isinstance(module, TransformerEngineBaseModule):
-            visited_te_modules.add(module)
-        # If forward is called on a BasicOperation directly the hook will run
-        elif isinstance(module, BasicOperation):
-            visited_te_modules.add(module)
-        # If forward is called on a te.ops.Sequential it is not called on its constituent ops
-        elif isinstance(module, Sequential):
-            assert module._module_groups is not None, "Should have been initialized by warmup"
-            for module_group in module._module_groups:
-                if isinstance(module_group, OperationFuser):
-                    for basic_op in module_group._basic_ops:
-                        visited_te_modules.add(basic_op)
+    visited_te_modules = {}
+    need_bwd_dw_graph = {}

     # Run warmup and do the above filtering.
     with torch.cuda.stream(torch.cuda.Stream()):
@@ -388,6 +377,31 @@ def _make_graphed_callables(
             args = sample_args[func_idx]
             kwargs = sample_kwargs[func_idx]
             static_input_surface = per_callable_static_input_surfaces[func_idx]
+
+            def hook_fn(
+                module, inputs, outputs, func_idx=func_idx
+            ):  # pylint: disable=unused-argument
+                modules = set()
+                if isinstance(module, TransformerEngineBaseModule):
+                    modules.add(module)
+                # If forward is called on a BasicOperation directly the hook will run
+                elif isinstance(module, BasicOperation):
+                    modules.add(module)
+                # If forward is called on a te.ops.Sequential it is not called on its constituent ops
+                elif isinstance(module, Sequential):
+                    assert (
+                        module._module_groups is not None
+                    ), "Should have been initialized by warmup"
+                    for module_group in module._module_groups:
+                        if isinstance(module_group, OperationFuser):
+                            for basic_op in module_group._basic_ops:
+                                modules.add(basic_op)
+                if modules:
+                    if func_idx not in visited_te_modules:
+                        visited_te_modules[func_idx] = modules
+                    else:
+                        visited_te_modules[func_idx].update(modules)
+
             for warmup_iter in range(num_warmup_iters):
                 hooks = []
                 for module in func.modules():
@@ -432,6 +446,15 @@ def _make_graphed_callables(
                         module_params_with_grad
                     )
                    per_callable_static_input_surfaces[func_idx] = static_input_surface
+
+                    # Run wgrad. This is essential for some TE modules when they have
+                    # delay_wgrad_compute enabled.
+                    need_backward_dw = False
+                    for module in visited_te_modules.get(func_idx, set()):
+                        if hasattr(module, "need_backward_dw") and module.need_backward_dw():
+                            need_backward_dw = True
+                            module.backward_dw()
+                    need_bwd_dw_graph[func_idx] = need_backward_dw
                 else:
                     grad_inputs = None
                 del outputs, grad_inputs
@@ -514,6 +537,17 @@ def _make_graphed_callables(
                     allow_unused=allow_unused_input,
                     retain_graph=retain_graph_in_backward,
                 )
+                # If no module needs backward_dw, the bwd_dw_graph will be empty,
+                # so skip capturing it.
+                if need_bwd_dw_graph[per_callable_bwd_idx]:
+                    bwd_dw_graph = bwd_dw_graphs[per_callable_bwd_idx]
+                    with _graph_context_wrapper(bwd_dw_graph, pool=mempool):
+                        for module in visited_te_modules[per_callable_bwd_idx]:
+                            if (
+                                hasattr(module, "need_backward_dw")
+                                and module.need_backward_dw()
+                            ):
+                                module.backward_dw()
                 # Constructs a tuple suitable for returning from Graphed.backward:
                 # Pads out the actually-needed grads with Nones in gradient slots for inputs
                 # that don't require grad. I couldn't think of a one-liner for this pattern.
@@ -582,10 +616,12 @@ def _make_graphed_callables(
     # Capture backward graphs in reverse order
     per_callable_static_grad_outputs = []
     per_callable_static_grad_inputs = []
-    for static_input_surface, static_outputs, bwd_graph in zip(
+    for static_input_surface, static_outputs, bwd_graph, bwd_dw_graph, bwd_idx in zip(
         reversed(per_callable_static_input_surfaces),
         reversed(per_callable_static_outputs),
         reversed(bwd_graphs),
+        reversed(bwd_dw_graphs),
+        reversed(range(len(per_callable_static_input_surfaces))),
     ):
         # For now, assumes all static_outputs require grad
         static_grad_outputs = tuple(
@@ -601,6 +637,11 @@ def _make_graphed_callables(
             allow_unused=allow_unused_input,
             retain_graph=retain_graph_in_backward,
         )
+        if need_bwd_dw_graph[bwd_idx]:
+            with _graph_context_wrapper(bwd_dw_graph, pool=mempool):
+                for module in visited_te_modules[bwd_idx]:
+                    if hasattr(module, "need_backward_dw") and module.need_backward_dw():
+                        module.backward_dw()
         # Constructs a tuple suitable for returning from Graphed.backward:
         # Pads out the actually-needed grads with Nones in gradient slots for inputs that
         # don't require grad. I couldn't think of a slick one-liner for this pattern.
@@ -732,9 +773,10 @@ def _make_graphed_callables(
         )
         func = graph_callables[i]
+        te_modules = visited_te_modules.get(i, set())
         if isinstance(func, torch.nn.Module):

-            def make_graphed_forward(func, graph_training_state, graphed, orig_fwd):
+            def make_graphed_forward(func, graph_training_state, graphed, orig_fwd, te_modules):
                 def new_fwd(*user_args, **user_kwargs):
                     # If the module's training-or-eval state matches what we graphed,
                     # run the graph, otherwise run the original forward method
@@ -743,7 +785,7 @@ def _make_graphed_callables(
                     if FP8GlobalStateManager.is_fp8_enabled():
                         fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
                         for m in func.modules():
-                            if m not in visited_te_modules:
+                            if m not in te_modules:
                                 # Only Set the FP8 meta for the modules included by forward
                                 continue
                             if isinstance(m, TransformerEngineBaseModule):
@@ -780,7 +822,7 @@ def _make_graphed_callables(
                 return new_fwd

-            forward = make_graphed_forward(func, func.training, graphed, func.forward)
+            forward = make_graphed_forward(func, func.training, graphed, func.forward, te_modules)
             if _order is None:
                 func.forward = forward
                 ret.append(func)
@@ -789,6 +831,16 @@ def _make_graphed_callables(
         else:
             ret.append(graphed)

+        # Attach backward_dw as an attribute to the graphed callable.
+        def backward_dw(
+            need_backward_dw=need_bwd_dw_graph.get(i, False),
+            bwd_dw_graph=bwd_dw_graphs[i],
+        ):
+            if need_backward_dw:
+                bwd_dw_graph.replay()
+
+        setattr(ret[-1], "backward_dw", backward_dw)
+
     if just_one_callable:
         return ret[0]
...
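As a side note, the capture-and-replay flow added above can be summarized with plain PyTorch primitives. The sketch below is a simplified stand-in: the tensors and the wgrad GEMM are invented for illustration, and TE's real code routes this through _graph_context_wrapper and each module's wgrad_store. The deferred weight-gradient work is warmed up on a side stream, captured into its own torch.cuda.CUDAGraph sharing a memory pool, and later replayed by the backward_dw() attribute.

# Simplified, standalone illustration of the capture/replay pattern for the
# delayed-wgrad graph (plain PyTorch; not TE's actual implementation).
import torch

# Static tensors that stand in for saved activations / incoming grads.
inp = torch.randn(64, 128, device="cuda")
grad_out = torch.randn(64, 256, device="cuda")
wgrad = torch.zeros(256, 128, device="cuda")

def delayed_wgrad_step():
    # The deferred weight-gradient GEMM: wgrad = grad_out^T @ inp.
    torch.matmul(grad_out.t(), inp, out=wgrad)

# Warmup on a side stream so capture sees the kernels/allocations once.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    delayed_wgrad_step()
torch.cuda.current_stream().wait_stream(s)

# Capture the wgrad step into its own CUDA graph, sharing a memory pool,
# as the PR does per callable with bwd_dw_graphs[i].
pool = torch.cuda.graph_pool_handle()
bwd_dw_graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(bwd_dw_graph, pool=pool):
    delayed_wgrad_step()

# Later, the graphed callable's backward_dw() attribute just replays this graph.
bwd_dw_graph.replay()
torch.cuda.synchronize()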
@@ -662,6 +662,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         self._fp8_workspaces: Dict[str, QuantizedTensor] = {}
         self.activation_dtype: Optional[torch.dtype] = None
         self.wgrad_accumulation_and_reduce_hooks = []
+        self.wgrad_store = None
         if not TEDebugState.debug_enabled:
             TEDebugState.initialize()
@@ -1481,12 +1482,21 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         """
         self.wgrad_accumulation_and_reduce_hooks.append(wgrad_accumulation_and_reduce_hook)

+    def need_backward_dw(self):
+        """
+        Check whether this module needs to execute the delayed weight gradient computation.
+        This method should be called at the beginning of self.backward_dw() to determine
+        whether it should actually run or simply return without doing anything.
+        Users can also call it manually before calling backward_dw().
+        """
+        return self.wgrad_store is not None and self.wgrad_store.delay_wgrad_compute()
+
     def backward_dw(self):
         """
         Execute the delayed weight gradient computation.
         This method is called after the main backward pass to compute weight gradients.
         """
-        if self.wgrad_store is None or not self.wgrad_store.delay_wgrad_compute():
+        if not self.need_backward_dw():
             return
         with torch.cuda.nvtx.range(f"_{self.__class__.__name__}_wgrad"):
             (wgrad, bgrad), _ = self.wgrad_store.pop()
...
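The wgrad_store contract that need_backward_dw() relies on is only visible here through its call sites, delay_wgrad_compute() and pop(). A toy stand-in with that inferred interface, purely for illustration and not TE's actual WgradStore, would look like:

# Toy stand-in for the store interface implied by the diff above
# (delay_wgrad_compute() and pop()); not TE's actual WgradStore.
from typing import Any, List, Optional, Tuple

class _ToyWgradStore:
    def __init__(self, delay: bool):
        self._delay = delay
        self._entries: List[Tuple[Any, Optional[Any]]] = []

    def delay_wgrad_compute(self) -> bool:
        # True when wgrad work is being deferred past the main backward pass.
        return self._delay

    def put(self, grads, tensors=None):
        self._entries.append((grads, tensors))

    def pop(self):
        # Hand back the oldest deferred (grads, tensors) entry.
        return self._entries.pop(0)

# The check introduced in this commit then reduces to:
def need_backward_dw(wgrad_store) -> bool:
    return wgrad_store is not None and wgrad_store.delay_wgrad_compute()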
@@ -840,7 +840,7 @@ class GroupedLinear(TransformerEngineBaseModule):
         Execute the delayed weight gradient computation.
         This method is called after the main backward pass to compute weight gradients.
         """
-        if self.wgrad_store is None or not self.wgrad_store.delay_wgrad_compute():
+        if not self.need_backward_dw():
             return
         with torch.cuda.nvtx.range("_GroupedLinear_wgrad"):
             (_, grad_biases_, _), tensor_list = self.wgrad_store.pop()
...
@@ -2211,7 +2211,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
         Execute the delayed weight gradient computation.
         This method is called after the main backward pass to compute weight gradients.
         """
-        if self.wgrad_store is None or not self.wgrad_store.delay_wgrad_compute():
+        if not self.need_backward_dw():
             return
         with torch.cuda.nvtx.range("_LayerNormMLP_wgrad"):
             (fc2_wgrad, fc2_bias_grad_, *_), tensor_list_fc2 = self.wgrad_store.pop()
...