Unverified commit c4cacbaa, authored by youkaichao, committed by GitHub
Browse files

[v1] reduce graph capture time for piecewise cudagraph (#10059)


Signed-off-by: youkaichao <youkaichao@gmail.com>
parent 0c63c34f
import copy
import dataclasses
import operator
from contextlib import ExitStack
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
from unittest.mock import patch
import torch
import torch.fx as fx
......@@ -503,6 +505,18 @@ class PiecewiseBackend:
entry.input_addresses = input_addresses
cudagraph = torch.cuda.CUDAGraph()
with ExitStack() as stack:
if not self.is_first_graph:
# during every model forward, we will capture
# many pieces of cudagraphs (roughly one per layer).
# running gc again and again across layers will
# make the cudagraph capture very slow.
# therefore, we only run gc for the first graph,
# and disable gc for the rest of the graphs.
stack.enter_context(patch("gc.collect", lambda: None))
stack.enter_context(
patch("torch.cuda.empty_cache", lambda: None))
# mind-exploding: carefully manage the reference and memory.
with torch.cuda.graph(cudagraph, pool=self.graph_pool):
# `output` is managed by pytorch's cudagraph pool
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment