Unverified Commit c4cacbaa authored by youkaichao, committed by GitHub
Browse files

[v1] reduce graph capture time for piecewise cudagraph (#10059)


Signed-off-by: youkaichao <youkaichao@gmail.com>
parent 0c63c34f
import copy import copy
import dataclasses import dataclasses
import operator import operator
from contextlib import ExitStack
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
from unittest.mock import patch
import torch import torch
import torch.fx as fx import torch.fx as fx
...@@ -503,17 +505,29 @@ class PiecewiseBackend: ...@@ -503,17 +505,29 @@ class PiecewiseBackend:
entry.input_addresses = input_addresses entry.input_addresses = input_addresses
cudagraph = torch.cuda.CUDAGraph() cudagraph = torch.cuda.CUDAGraph()
# mind-exploding: carefully manage the reference and memory. with ExitStack() as stack:
with torch.cuda.graph(cudagraph, pool=self.graph_pool): if not self.is_first_graph:
# `output` is managed by pytorch's cudagraph pool # during every model forward, we will capture
output = entry.runnable(*args) # many pieces of cudagraphs (roughly one per layer).
if self.is_last_graph: # running gc again and again across layers will
# by converting it to weak ref, # make the cudagraph capture very slow.
# the original `output` will immediately be released # therefore, we only run gc for the first graph,
# to save memory. It is only safe to do this for # and disable gc for the rest of the graphs.
# the last graph, because the output of the last graph stack.enter_context(patch("gc.collect", lambda: None))
# will not be used by any other cuda graph. stack.enter_context(
output = weak_ref_tensors(output) patch("torch.cuda.empty_cache", lambda: None))
# mind-exploding: carefully manage the reference and memory.
with torch.cuda.graph(cudagraph, pool=self.graph_pool):
# `output` is managed by pytorch's cudagraph pool
output = entry.runnable(*args)
if self.is_last_graph:
# by converting it to weak ref,
# the original `output` will immediately be released
# to save memory. It is only safe to do this for
# the last graph, because the output of the last graph
# will not be used by any other cuda graph.
output = weak_ref_tensors(output)
# here we always use weak ref for the output # here we always use weak ref for the output
# to save memory # to save memory
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment