Unverified Commit a132ac49 authored by Xiaowei Ren, committed by GitHub

Fix CUDA graph capture for grouped GEMM (#1345)



* retain_graph=True for grouped gemm
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* remove an unnecessary retain_graph=True
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* make retain_graph in graph capture configurable
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* typo fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

---------
Signed-off-by: Xiaowei Ren <xren@nvidia.com>
parent 60ce21f4
@@ -64,6 +64,7 @@ def _make_graphed_callables(
     sample_kwargs: Optional[SingleOrTuple[Dict[str, Any]]] = None,
     _order: Optional[List[int]] = None,
     pool: Optional[Tuple[int, ...]] = None,
+    retain_graph_in_backward: bool = False,
 ) -> SingleOrTuple[Callable]:
     """
     Helper method for `make_graphed_callables`
@@ -320,6 +321,7 @@ def _make_graphed_callables(
                 grad_outputs=tuple(o for o in static_grad_outputs if o is not None),
                 only_inputs=True,
                 allow_unused=allow_unused_input,
+                retain_graph=retain_graph_in_backward,
             )
         # Constructs a tuple suitable for returning from Graphed.backward:
         # Pads out the actually-needed grads with Nones in gradient slots for inputs
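For context (standard PyTorch behavior, not something this patch introduces): `torch.autograd.grad` frees the autograd graph after the first differentiation unless `retain_graph=True` is passed, so a backward capture that needs to differentiate through the same graph more than once must opt in. A minimal standalone sketch of that behavior:

```python
import torch

x = torch.randn(4, requires_grad=True)
y = (x * x).sum()

# retain_graph=True keeps the graph's intermediate buffers alive
# past this first differentiation ...
(g1,) = torch.autograd.grad(y, x, retain_graph=True)

# ... so a second differentiation through the same graph succeeds.
# Without retain_graph=True above, this call raises a RuntimeError.
(g2,) = torch.autograd.grad(y, x)
```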
@@ -371,6 +373,7 @@ def _make_graphed_callables(
                 grad_outputs=tuple(o for o in static_grad_outputs if o is not None),
                 only_inputs=True,
                 allow_unused=allow_unused_input,
+                retain_graph=retain_graph_in_backward,
             )
         # Constructs a tuple suitable for returning from Graphed.backward:
         # Pads out the actually-needed grads with Nones in gradient slots for inputs that
@@ -606,6 +609,7 @@ def make_graphed_callables(
     fp8_weight_caching: bool = False,
     _order: Optional[List[int]] = None,
     pool: Optional[Tuple[int, ...]] = None,
+    retain_graph_in_backward: bool = False,
 ) -> Union[Callable, Tuple[Callable, ...]]:
     """
     Make CUDA graph version of Transformer Engine modules
@@ -632,6 +636,8 @@ def make_graphed_callables(
     pool: (tuple of) int, default = `None`, optional
           An instance returned from function `torch.cuda.graph_pool_handle` that hints
           this graph may share memory with the indicated pool.
+    retain_graph_in_backward: bool, default = `False`
+                              Whether to set retain_graph=True in backward graph capture.
     FP8-related parameters
     ----------------------
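A minimal usage sketch of the new option; the module choice, shapes, and variable names below are illustrative assumptions, not taken from this commit:

```python
import torch
import transformer_engine.pytorch as te

# Illustrative module and sample shapes (assumptions, not from the commit).
model = te.Linear(1024, 1024)
sample_args = (torch.randn(32, 1024, device="cuda"),)

# Keep the autograd graph alive during backward-pass capture,
# e.g. for modules whose grouped-GEMM backward needs to reuse it.
graphed_model = te.make_graphed_callables(
    model,
    sample_args,
    retain_graph_in_backward=True,
)
```

The flag defaults to `False`, so existing callers keep the cheaper free-after-backward behavior and only grouped-GEMM-style workloads need to opt in.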
@@ -716,6 +722,7 @@ def make_graphed_callables(
             sample_kwargs=sample_kwargs,
             _order=_order,
             pool=pool,
+            retain_graph_in_backward=retain_graph_in_backward,
         )
     # Ensures warmup does not affect numerics for ops such as dropout.