Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
aff1fd81
Unverified
Commit
aff1fd81
authored
Nov 01, 2024
by
youkaichao
Committed by
GitHub
Nov 01, 2024
Browse files
[torch.compile] use interpreter with stable api from pytorch (#9889)
Signed-off-by:
youkaichao
<
youkaichao@gmail.com
>
parent
4581d2cc
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
89 additions
and
76 deletions
+89
-76
vllm/compilation/backends.py
vllm/compilation/backends.py
+89
-76
No files found.
vllm/compilation/backends.py
View file @
aff1fd81
...
...
@@ -243,6 +243,65 @@ def split_graph(graph: fx.GraphModule,
return
split_gm
,
outputs
# we share the global graph pool among all the backends
# Starts out as None; `VllmBackend.__init__` creates the handle with
# `torch.cuda.graph_pool_handle()` the first time it finds this unset and
# then reuses the same handle for every later backend instance.
global_graph_pool = None
class PiecewiseCompileInterpreter(torch.fx.Interpreter):
    """Interpreter that compiles selected submodules while walking a graph.

    Adapted from `torch.fx.passes.shape_prop.ShapeProp`. The wrapped graph
    module is executed with fake tensors standing in for tensor inputs.
    Whenever a `call_module` node targets one of the names listed in
    `compile_submod_names`, that submodule is compiled through
    `wrap_inductor` and the entry of the same name in the graph module's
    `__dict__` is replaced by a `PiecewiseBackend` wrapping it.
    """

    def __init__(self, module: torch.fx.GraphModule,
                 compile_submod_names: List[str],
                 compilation_configs: CompilationConfig, graph_pool):
        """Store the compilation settings and look up the fake mode.

        Args:
            module: the graph module to interpret.
            compile_submod_names: names of the submodules to compile.
            compilation_configs: configuration forwarded to `wrap_inductor`
                and to every `PiecewiseBackend` created here.
            graph_pool: graph pool handle passed to every
                `PiecewiseBackend` created here.
        """
        super().__init__(module)
        # Imported lazily, inside the constructor, as in the original code.
        from torch._guards import detect_fake_mode

        self.fake_mode = detect_fake_mode()
        self.compile_submod_names = compile_submod_names
        self.compilation_configs = compilation_configs
        self.graph_pool = graph_pool
        # Flipped to True once the first submodule has been compiled.
        self.have_seen_first_graph = False

    def run(self, *args):
        """Interpret the module with tensor arguments converted to fake ones.

        Non-tensor arguments are forwarded unchanged.
        """

        def _fakeify(value):
            if isinstance(value, torch.Tensor):
                return self.fake_mode.from_tensor(value)
            return value

        return super().run(*[_fakeify(value) for value in args])

    def call_module(self, target: torch.fx.node.Target,
                    args: Tuple[torch.fx.node.Argument, ...],
                    kwargs: Dict[str, Any]) -> Any:
        """Execute a submodule call and compile it if it was requested.

        The submodule is always executed through the base interpreter first.
        If `target` is one of `compile_submod_names`, the submodule is then
        compiled with the same arguments and swapped for a
        `PiecewiseBackend` in the graph module's `__dict__`.

        Returns:
            The output of the base interpreter's `call_module`.
        """
        assert isinstance(target, str)
        result = super().call_module(target, args, kwargs)

        # Nothing more to do for submodules that were not selected.
        if target not in self.compile_submod_names:
            return result

        configs = self.compilation_configs
        # Only the first compiled submodule logs and is flagged as first.
        is_first_graph = not self.have_seen_first_graph

        piece = self.fetch_attr(target)
        # Positions of the arguments that are symbolic integers.
        symint_positions = [
            position for position, value in enumerate(args)
            if isinstance(value, torch.SymInt)
        ]
        general_shape_callable = wrap_inductor(
            piece,
            args,
            configs.inductor_compile_config,
            runtime_shape=None,
            do_logging=is_first_graph,
            use_inductor=configs.use_inductor)

        # Written straight into `__dict__` rather than set as an attribute.
        self.module.__dict__[target] = PiecewiseBackend(
            piece, configs, self.graph_pool, is_first_graph,
            symint_positions, general_shape_callable)

        self.have_seen_first_graph = True
        compilation_counter.num_piecewise_capturable_graphs_seen += 1
        return result
class
VllmBackend
:
"""The compilation backend for `torch.compile` with VLLM.
It is used for compilation level of `CompilationLevel.PIECEWISE`,
...
...
@@ -263,8 +322,14 @@ class VllmBackend:
returned_callable
:
Callable
def
__init__
(
self
,
):
# every instance of VllmBackend has its own graph pool
self
.
graph_pool
=
torch
.
cuda
.
graph_pool_handle
()
global
global_graph_pool
if
global_graph_pool
is
None
:
global_graph_pool
=
torch
.
cuda
.
graph_pool_handle
()
# TODO: in the future, if we want to use multiple
# streams, it might not be safe to share a global pool.
# only investigate this when we use multiple streams
self
.
graph_pool
=
global_graph_pool
# `torch.compile` is JIT compiled, so we don't need to
# do anything here
...
...
@@ -286,55 +351,26 @@ class VllmBackend:
self
.
split_gm
,
self
.
piecewise_graphs
=
split_graph
(
graph
,
self
.
compilation_configs
.
non_cudagraph_ops
)
returned_callable
:
Callable
# type: ignore
from
torch._dynamo.utils
import
lazy_format_graph_code
logger
.
debug
(
"%s"
,
lazy_format_graph_code
(
"stiching module"
,
self
.
split_gm
))
if
len
(
self
.
piecewise_graphs
)
==
0
:
compilation_counter
.
num_piecewise_graphs_seen
+=
1
compilation_counter
.
num_piecewise_capturable_graphs_seen
+=
1
returned_callable
=
PiecewiseBackend
(
graph
,
self
.
compilation_configs
,
self
.
graph_pool
,
is_first_graph
=
True
)
else
:
from
torch._dynamo.utils
import
lazy_format_graph_code
logger
.
debug
(
"%s"
,
lazy_format_graph_code
(
"stiching module"
,
self
.
split_gm
))
is_first_graph
=
True
for
item
in
self
.
piecewise_graphs
:
compilation_counter
.
num_piecewise_graphs_seen
+=
1
compilation_counter
.
num_piecewise_capturable_graphs_seen
+=
not
item
.
is_splitting_graph
# noqa
if
not
item
.
is_splitting_graph
:
# cannot setattr to a module, so we need to set
# the attribute in the __dict__
self
.
split_gm
.
__dict__
[
item
.
submod_name
]
=
PiecewiseBackend
(
item
.
graph
,
self
.
compilation_configs
,
self
.
graph_pool
,
is_first_graph
)
is_first_graph
=
False
returned_callable
=
self
.
split_gm
self
.
returned_callable
=
returned_callable
# trigger the first compilation
# code borrowed from https://github.com/pytorch/pytorch/blob/4e3e08b71171fa34172b2362ff668553fac75f27/torch/_dynamo/backends/distributed.py#L206 # noqa
# to turn the inputs into fake tensors
import
torch._guards
from
torch._guards
import
detect_fake_mode
fake_mode
=
detect_fake_mode
(
example_inputs
)
fake_args
=
[]
for
arg
in
example_inputs
:
if
isinstance
(
arg
,
torch
.
Tensor
)
and
not
isinstance
(
arg
,
torch
.
_subclasses
.
FakeTensor
):
fake_args
.
append
(
torch
.
_dynamo
.
utils
.
to_fake_tensor
(
arg
,
fake_mode
))
else
:
fake_args
.
append
(
arg
)
self
.
returned_callable
(
*
fake_args
)
compilation_counter
.
num_piecewise_graphs_seen
+=
len
(
self
.
piecewise_graphs
)
submod_names_to_compile
=
[
item
.
submod_name
for
item
in
self
.
piecewise_graphs
if
not
item
.
is_splitting_graph
]
# propagate the split graph to the piecewise backend,
# compile submodules with symbolic shapes
PiecewiseCompileInterpreter
(
self
.
split_gm
,
submod_names_to_compile
,
self
.
compilation_configs
,
self
.
graph_pool
).
run
(
*
example_inputs
)
self
.
_called
=
True
return
self
.
returned_callable
return
self
.
split_gm
@
dataclasses
.
dataclass
...
...
@@ -352,11 +388,10 @@ class ConcreteSizeEntry:
class
PiecewiseBackend
:
def
__init__
(
self
,
graph
:
fx
.
GraphModule
,
compilation_configs
:
CompilationConfig
,
graph_pool
:
Any
,
is_first_graph
:
bool
=
False
):
def
__init__
(
self
,
graph
:
fx
.
GraphModule
,
compilation_configs
:
CompilationConfig
,
graph_pool
:
Any
,
is_first_graph
:
bool
,
sym_shape_indices
:
List
[
int
],
compiled_graph_for_general_shape
:
Callable
):
"""
The backend for piecewise compilation.
It mainly handles the compilation and cudagraph capturing.
...
...
@@ -381,12 +416,11 @@ class PiecewiseBackend:
self
.
compilation_configs
.
capture_sizes
)
if
self
.
compilation_configs
.
use_cudagraph
else
set
()
self
.
compile_finished
=
False
self
.
first_run_finished
=
False
self
.
compiled_graph_for_general_shape
:
Callable
=
None
# type: ignore
self
.
compiled_graph_for_general_shape
=
compiled_graph_for_general_shape
# noqa
self
.
sym_shape_indices
:
List
[
int
]
=
[]
self
.
sym_shape_indices
=
sym_shape_indices
# the entries for different shapes that we need to either
# compile or capture cudagraph
...
...
@@ -399,27 +433,6 @@ class PiecewiseBackend:
)
def
__call__
(
self
,
*
args
)
->
Any
:
if
not
self
.
compile_finished
:
self
.
compile_finished
=
True
# this is the first compilation, we will compile a graph with
# dynamic shape, as the caller will mark first dimension as dynamic
self
.
sym_shape_indices
=
[
i
for
i
,
x
in
enumerate
(
args
)
if
isinstance
(
x
,
torch
.
SymInt
)
]
self
.
compiled_graph_for_general_shape
=
wrap_inductor
(
self
.
graph
,
args
,
self
.
compilation_configs
.
inductor_compile_config
,
runtime_shape
=
None
,
do_logging
=
self
.
is_first_graph
,
use_inductor
=
self
.
compilation_configs
.
use_inductor
)
return
self
.
graph
(
*
args
)
if
not
self
.
first_run_finished
:
self
.
first_run_finished
=
True
return
self
.
compiled_graph_for_general_shape
(
*
args
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment