Unverified Commit 728c558b authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

[PyTorch] Add pool argument to make_graphed_callable (#1218)



Add pool argument to make_graphed_callable
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 7b152a83
...@@ -61,6 +61,7 @@ def _make_graphed_callables( ...@@ -61,6 +61,7 @@ def _make_graphed_callables(
fp8_weight_caching: bool = False, fp8_weight_caching: bool = False,
sample_kwargs: Optional[SingleOrTuple[Dict[str, Any]]] = None, sample_kwargs: Optional[SingleOrTuple[Dict[str, Any]]] = None,
_order: Optional[List[int]] = None, _order: Optional[List[int]] = None,
pool: Optional[Tuple[int, ...]] = None,
) -> SingleOrTuple[Callable]: ) -> SingleOrTuple[Callable]:
""" """
Helper method for `make_graphed_callables` Helper method for `make_graphed_callables`
...@@ -193,7 +194,7 @@ def _make_graphed_callables( ...@@ -193,7 +194,7 @@ def _make_graphed_callables(
fwd_graph.register_generator_state(state) fwd_graph.register_generator_state(state)
bwd_graph.register_generator_state(state) bwd_graph.register_generator_state(state)
mempool = graph_pool_handle() mempool = graph_pool_handle() if pool is None else pool
# Warmup # Warmup
# Hopefully prevents cudnn benchmarking and other lazy-initialization cuda work # Hopefully prevents cudnn benchmarking and other lazy-initialization cuda work
...@@ -518,6 +519,7 @@ def make_graphed_callables( ...@@ -518,6 +519,7 @@ def make_graphed_callables(
fp8_recipe: Optional[DelayedScaling] = None, fp8_recipe: Optional[DelayedScaling] = None,
fp8_weight_caching: bool = False, fp8_weight_caching: bool = False,
_order: Optional[List[int]] = None, _order: Optional[List[int]] = None,
pool: Optional[Tuple[int, ...]] = None,
) -> Union[Callable, Tuple[Callable, ...]]: ) -> Union[Callable, Tuple[Callable, ...]]:
""" """
Make CUDA graph version of Transformer Engine modules Make CUDA graph version of Transformer Engine modules
...@@ -541,6 +543,9 @@ def make_graphed_callables( ...@@ -541,6 +543,9 @@ def make_graphed_callables(
and outputs are disconnected in compute graph. and outputs are disconnected in compute graph.
sample_kwargs: (tuple of) dict, optional sample_kwargs: (tuple of) dict, optional
Keyword arguments to callable(s) Keyword arguments to callable(s)
pool: (tuple of) int, default = `None`, optional
An instance returned from function `torch.cuda.graph_pool_handle` that hints
this graph may share memory with the indicated pool.
FP8-related parameters FP8-related parameters
---------------------- ----------------------
...@@ -617,6 +622,7 @@ def make_graphed_callables( ...@@ -617,6 +622,7 @@ def make_graphed_callables(
fp8_weight_caching=fp8_weight_caching, fp8_weight_caching=fp8_weight_caching,
sample_kwargs=sample_kwargs, sample_kwargs=sample_kwargs,
_order=_order, _order=_order,
pool=pool,
) )
# Ensures warmup does not affect numerics for ops such as dropout. # Ensures warmup does not affect numerics for ops such as dropout.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment