Unverified Commit 728c558b authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

[PyTorch] Add pool argument to make_graphed_callable (#1218)



Add pool argument to make_graphed_callable
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 7b152a83
......@@ -61,6 +61,7 @@ def _make_graphed_callables(
fp8_weight_caching: bool = False,
sample_kwargs: Optional[SingleOrTuple[Dict[str, Any]]] = None,
_order: Optional[List[int]] = None,
pool: Optional[Tuple[int, ...]] = None,
) -> SingleOrTuple[Callable]:
"""
Helper method for `make_graphed_callables`
......@@ -193,7 +194,7 @@ def _make_graphed_callables(
fwd_graph.register_generator_state(state)
bwd_graph.register_generator_state(state)
mempool = graph_pool_handle()
mempool = graph_pool_handle() if pool is None else pool
# Warmup
# Hopefully prevents cudnn benchmarking and other lazy-initialization cuda work
......@@ -518,6 +519,7 @@ def make_graphed_callables(
fp8_recipe: Optional[DelayedScaling] = None,
fp8_weight_caching: bool = False,
_order: Optional[List[int]] = None,
pool: Optional[Tuple[int, ...]] = None,
) -> Union[Callable, Tuple[Callable, ...]]:
"""
Make CUDA graph version of Transformer Engine modules
......@@ -541,6 +543,9 @@ def make_graphed_callables(
and outputs are disconnected in compute graph.
sample_kwargs: (tuple of) dict, optional
Keyword arguments to callable(s)
pool: (tuple of) int, default = `None`, optional
An instance returned from function `torch.cuda.graph_pool_handle` that hints
this graph may share memory with the indicated pool.
FP8-related parameters
----------------------
......@@ -617,6 +622,7 @@ def make_graphed_callables(
fp8_weight_caching=fp8_weight_caching,
sample_kwargs=sample_kwargs,
_order=_order,
pool=pool,
)
# Ensures warmup does not affect numerics for ops such as dropout.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment