Unverified Commit 1fa1e53a authored by Simon Mo's avatar Simon Mo Committed by GitHub
Browse files

Revert "[compile] Initialize passes at VllmBackend init" (#37733)

parent 3ffa5200
...@@ -32,9 +32,9 @@ from vllm.platforms import current_platform ...@@ -32,9 +32,9 @@ from vllm.platforms import current_platform
def test_compile_config_repr_succeeds(): def test_compile_config_repr_succeeds():
# setup: VllmBackend mutates the config object # setup: VllmBackend mutates the config object
# Note: VllmBackend.__init__ already calls configure_post_pass()
config = VllmConfig() config = VllmConfig()
_ = VllmBackend(config) backend = VllmBackend(config)
backend.configure_post_pass()
# test that repr(config) succeeds # test that repr(config) succeeds
val = repr(config) val = repr(config)
......
...@@ -836,18 +836,8 @@ class VllmBackend: ...@@ -836,18 +836,8 @@ class VllmBackend:
# in future we need PostGradPassManager.uuid() to be executed # in future we need PostGradPassManager.uuid() to be executed
# only at compile time. # only at compile time.
self.inductor_config = deepcopy(self.compilation_config.inductor_compile_config) self.inductor_config = deepcopy(self.compilation_config.inductor_compile_config)
# `torch.compile` is JIT compiled, so we don't need to
# Configure post-grad passes (including AllReduceFusionPass) during # do anything here
# backend init rather than at torch.compile time, so that expensive
# one-time setup (e.g. FlashInfer workspace allocation) is not
# attributed to compilation latency.
start = time.time()
self.configure_post_pass()
logger.info_once(
"Post-grad pass configuration time: %.2f s",
time.time() - start,
scope="local",
)
def collect_standalone_compile_artifacts( def collect_standalone_compile_artifacts(
self, self,
...@@ -1128,6 +1118,7 @@ class VllmBackend: ...@@ -1128,6 +1118,7 @@ class VllmBackend:
assert not self._called, "VllmBackend can only be called once" assert not self._called, "VllmBackend can only be called once"
self.graph = graph self.graph = graph
self.configure_post_pass()
if self.compilation_config.use_inductor_graph_partition: if self.compilation_config.use_inductor_graph_partition:
# Let Inductor decide partitioning; avoid FX-level pre-splitting. # Let Inductor decide partitioning; avoid FX-level pre-splitting.
......
...@@ -380,11 +380,6 @@ def _support_torch_compile( ...@@ -380,11 +380,6 @@ def _support_torch_compile(
compilation_counter.num_models_seen += 1 compilation_counter.num_models_seen += 1
self.compiled = False self.compiled = False
# Skip if a parent class's @support_torch_compile already
# initialized the compile wrapper
if hasattr(self, "_compiled_callable"):
return
# Handled by monkeypatching `TorchCompileWithNoGuardsWrapper` into base class # Handled by monkeypatching `TorchCompileWithNoGuardsWrapper` into base class
TorchCompileWithNoGuardsWrapper.__init__( TorchCompileWithNoGuardsWrapper.__init__(
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment