Unverified commit aff1fd81, authored by youkaichao and committed by GitHub
Browse files

[torch.compile] use interpreter with stable api from pytorch (#9889)


Signed-off-by: youkaichao <youkaichao@gmail.com>
parent 4581d2cc
...@@ -243,6 +243,65 @@ def split_graph(graph: fx.GraphModule, ...@@ -243,6 +243,65 @@ def split_graph(graph: fx.GraphModule,
return split_gm, outputs return split_gm, outputs
# we share the global graph pool among all the backends
# (starts as None; `VllmBackend.__init__` fills it in with
# `torch.cuda.graph_pool_handle()` the first time it finds it unset,
# and every backend instance then reuses that same handle)
global_graph_pool = None
class PiecewiseCompileInterpreter(torch.fx.Interpreter):
    """Interpreter that compiles selected submodules while running a graph.

    Adapted from `torch.fx.passes.shape_prop.ShapeProp`. The wrapped graph
    module is executed with fake inputs; every submodule whose name is
    listed in `compile_submod_names` is additionally compiled with the
    given compilation configs and swapped for a `PiecewiseBackend`.
    """

    def __init__(self, module: torch.fx.GraphModule,
                 compile_submod_names: List[str],
                 compilation_configs: CompilationConfig, graph_pool):
        """Store the compilation settings and look up the fake mode.

        Args:
            module: the graph module to interpret.
            compile_submod_names: names of the submodules to compile.
            compilation_configs: configs forwarded to `wrap_inductor`
                and `PiecewiseBackend`.
            graph_pool: graph pool handed to every `PiecewiseBackend`.
        """
        super().__init__(module)

        from torch._guards import detect_fake_mode

        self.fake_mode = detect_fake_mode()
        self.graph_pool = graph_pool
        self.compilation_configs = compilation_configs
        self.compile_submod_names = compile_submod_names
        # flips to True once the first listed submodule has been compiled
        self.have_seen_first_graph = False

    def run(self, *args):
        """Run the graph after converting tensor arguments with the fake mode.

        Non-tensor arguments are passed through untouched.
        """
        fake_args = []
        for arg in args:
            if isinstance(arg, torch.Tensor):
                arg = self.fake_mode.from_tensor(arg)
            fake_args.append(arg)
        return super().run(*fake_args)

    def call_module(self, target: torch.fx.node.Target,
                    args: Tuple[torch.fx.node.Argument,
                                ...], kwargs: Dict[str, Any]) -> Any:
        """Execute submodule `target`, compiling it if it is listed.

        The submodule is always executed first through the base
        interpreter and that result is what gets returned. When `target`
        is in `compile_submod_names`, the submodule is also compiled via
        `wrap_inductor` and replaced by a `PiecewiseBackend`.
        """
        assert isinstance(target, str)
        output = super().call_module(target, args, kwargs)

        # submodules that are not listed are only executed, never compiled
        if target not in self.compile_submod_names:
            return output

        submod = self.fetch_attr(target)
        configs = self.compilation_configs
        is_first_graph = not self.have_seen_first_graph

        # positions of the arguments that are symbolic integers
        sym_shape_indices = [
            position for position, value in enumerate(args)
            if isinstance(value, torch.SymInt)
        ]

        general_shape_graph = wrap_inductor(
            submod,
            args,
            configs.inductor_compile_config,
            runtime_shape=None,
            do_logging=is_first_graph,
            use_inductor=configs.use_inductor)

        # cannot setattr to a module, so the replacement is stored
        # directly in the module's __dict__
        self.module.__dict__[target] = PiecewiseBackend(
            submod, configs, self.graph_pool, is_first_graph,
            sym_shape_indices, general_shape_graph)

        self.have_seen_first_graph = True
        compilation_counter.num_piecewise_capturable_graphs_seen += 1
        return output
class VllmBackend: class VllmBackend:
"""The compilation backend for `torch.compile` with VLLM. """The compilation backend for `torch.compile` with VLLM.
It is used for compilation level of `CompilationLevel.PIECEWISE`, It is used for compilation level of `CompilationLevel.PIECEWISE`,
...@@ -263,8 +322,14 @@ class VllmBackend: ...@@ -263,8 +322,14 @@ class VllmBackend:
returned_callable: Callable returned_callable: Callable
def __init__(self, ): def __init__(self, ):
# every instance of VllmBackend has its own graph pool global global_graph_pool
self.graph_pool = torch.cuda.graph_pool_handle() if global_graph_pool is None:
global_graph_pool = torch.cuda.graph_pool_handle()
# TODO: in the future, if we want to use multiple
# streams, it might not be safe to share a global pool.
# only investigate this when we use multiple streams
self.graph_pool = global_graph_pool
# `torch.compile` is JIT compiled, so we don't need to # `torch.compile` is JIT compiled, so we don't need to
# do anything here # do anything here
...@@ -286,55 +351,26 @@ class VllmBackend: ...@@ -286,55 +351,26 @@ class VllmBackend:
self.split_gm, self.piecewise_graphs = split_graph( self.split_gm, self.piecewise_graphs = split_graph(
graph, self.compilation_configs.non_cudagraph_ops) graph, self.compilation_configs.non_cudagraph_ops)
returned_callable: Callable # type: ignore from torch._dynamo.utils import lazy_format_graph_code
logger.debug("%s",
lazy_format_graph_code("stiching module", self.split_gm))
if len(self.piecewise_graphs) == 0: compilation_counter.num_piecewise_graphs_seen += len(
compilation_counter.num_piecewise_graphs_seen += 1 self.piecewise_graphs)
compilation_counter.num_piecewise_capturable_graphs_seen += 1 submod_names_to_compile = [
returned_callable = PiecewiseBackend(graph, item.submod_name for item in self.piecewise_graphs
self.compilation_configs, if not item.is_splitting_graph
self.graph_pool, ]
is_first_graph=True)
else: # propagate the split graph to the piecewise backend,
from torch._dynamo.utils import lazy_format_graph_code # compile submodules with symbolic shapes
logger.debug( PiecewiseCompileInterpreter(self.split_gm, submod_names_to_compile,
"%s", lazy_format_graph_code("stiching module", self.split_gm)) self.compilation_configs,
self.graph_pool).run(*example_inputs)
is_first_graph = True
for item in self.piecewise_graphs:
compilation_counter.num_piecewise_graphs_seen += 1
compilation_counter.num_piecewise_capturable_graphs_seen += not item.is_splitting_graph # noqa
if not item.is_splitting_graph:
# cannot setattr to a module, so we need to set
# the attribute in the __dict__
self.split_gm.__dict__[
item.submod_name] = PiecewiseBackend(
item.graph, self.compilation_configs,
self.graph_pool, is_first_graph)
is_first_graph = False
returned_callable = self.split_gm
self.returned_callable = returned_callable
# trigger the first compilation
# code borrowed from https://github.com/pytorch/pytorch/blob/4e3e08b71171fa34172b2362ff668553fac75f27/torch/_dynamo/backends/distributed.py#L206 # noqa
# to turn the inputs into fake tensors
import torch._guards
from torch._guards import detect_fake_mode
fake_mode = detect_fake_mode(example_inputs)
fake_args = []
for arg in example_inputs:
if isinstance(arg, torch.Tensor) and not isinstance(
arg, torch._subclasses.FakeTensor):
fake_args.append(
torch._dynamo.utils.to_fake_tensor(arg, fake_mode))
else:
fake_args.append(arg)
self.returned_callable(*fake_args)
self._called = True self._called = True
return self.returned_callable return self.split_gm
@dataclasses.dataclass @dataclasses.dataclass
...@@ -352,11 +388,10 @@ class ConcreteSizeEntry: ...@@ -352,11 +388,10 @@ class ConcreteSizeEntry:
class PiecewiseBackend: class PiecewiseBackend:
def __init__(self, def __init__(self, graph: fx.GraphModule,
graph: fx.GraphModule, compilation_configs: CompilationConfig, graph_pool: Any,
compilation_configs: CompilationConfig, is_first_graph: bool, sym_shape_indices: List[int],
graph_pool: Any, compiled_graph_for_general_shape: Callable):
is_first_graph: bool = False):
""" """
The backend for piecewise compilation. The backend for piecewise compilation.
It mainly handles the compilation and cudagraph capturing. It mainly handles the compilation and cudagraph capturing.
...@@ -381,12 +416,11 @@ class PiecewiseBackend: ...@@ -381,12 +416,11 @@ class PiecewiseBackend:
self.compilation_configs.capture_sizes self.compilation_configs.capture_sizes
) if self.compilation_configs.use_cudagraph else set() ) if self.compilation_configs.use_cudagraph else set()
self.compile_finished = False
self.first_run_finished = False self.first_run_finished = False
self.compiled_graph_for_general_shape: Callable = None # type: ignore self.compiled_graph_for_general_shape = compiled_graph_for_general_shape # noqa
self.sym_shape_indices: List[int] = [] self.sym_shape_indices = sym_shape_indices
# the entries for different shapes that we need to either # the entries for different shapes that we need to either
# compile or capture cudagraph # compile or capture cudagraph
...@@ -399,27 +433,6 @@ class PiecewiseBackend: ...@@ -399,27 +433,6 @@ class PiecewiseBackend:
) )
def __call__(self, *args) -> Any: def __call__(self, *args) -> Any:
if not self.compile_finished:
self.compile_finished = True
# this is the first compilation, we will compile a graph with
# dynamic shape, as the caller will mark first dimension as dynamic
self.sym_shape_indices = [
i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
]
self.compiled_graph_for_general_shape = wrap_inductor(
self.graph,
args,
self.compilation_configs.inductor_compile_config,
runtime_shape=None,
do_logging=self.is_first_graph,
use_inductor=self.compilation_configs.use_inductor)
return self.graph(*args)
if not self.first_run_finished: if not self.first_run_finished:
self.first_run_finished = True self.first_run_finished = True
return self.compiled_graph_for_general_shape(*args) return self.compiled_graph_for_general_shape(*args)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment