Unverified commit bdd8981d, authored by Zhengxu Chen, committed by GitHub
Browse files

[compile] Apply stored functorch config while finalizing loaded artifacts. (#36582)


Signed-off-by: zhxchen17 <zhxchen17@fb.com>
parent f088a831
...@@ -369,8 +369,14 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc] ...@@ -369,8 +369,14 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
from vllm.compilation.backends import VllmBackend from vllm.compilation.backends import VllmBackend
saved_aot_autograd_config = self.aot_autograd_config
if saved_aot_autograd_config is not None:
functorch_ctx = torch._functorch.config.patch(saved_aot_autograd_config)
else:
functorch_ctx = contextlib.nullcontext()
vllm_backend = VllmBackend(vllm_config, self.prefix, self.is_encoder) vllm_backend = VllmBackend(vllm_config, self.prefix, self.is_encoder)
with tracing(TracingContext(self._fake_mode)): with tracing(TracingContext(self._fake_mode)), functorch_ctx:
result = vllm_backend(self.graph_module, list(self.example_inputs)) result = vllm_backend(self.graph_module, list(self.example_inputs))
self.optimized_call = result.optimized_call self.optimized_call = result.optimized_call
self.vllm_backend = vllm_backend self.vllm_backend = vllm_backend
......
...@@ -258,31 +258,15 @@ class PiecewiseBackend: ...@@ -258,31 +258,15 @@ class PiecewiseBackend:
else: else:
args_list = get_fake_args_from_graph(self.graph) args_list = get_fake_args_from_graph(self.graph)
# TODO(https://github.com/vllm-project/vllm/issues/35766) range_entry.runnable = self.vllm_backend.compiler_manager.compile(
# Can we remove strict_autograd_cache and self.graph,
# force_non_lazy_backward_lowering overrides? args_list,
# I added them explicitly because this is what they are self.vllm_backend.inductor_config,
# set to before the refactor self.compilation_config,
# (https://github.com/vllm-project/vllm/pull/35472). compile_range=range_entry.compile_range,
# They affect the aotautograd cache key computation graph_index=self.piecewise_compile_index,
# but they shouldn't have any effect on the actual num_graphs=self.total_piecewise_compiles,
# compilation.
config_patches = dict(
bundled_autograd_cache=True,
strict_autograd_cache=False,
) )
if hasattr(torch._functorch.config, "force_non_lazy_backward_lowering"):
config_patches["force_non_lazy_backward_lowering"] = False
with torch._functorch.config.patch(**config_patches):
range_entry.runnable = self.vllm_backend.compiler_manager.compile(
self.graph,
args_list,
self.vllm_backend.inductor_config,
self.compilation_config,
compile_range=range_entry.compile_range,
graph_index=self.piecewise_compile_index,
num_graphs=self.total_piecewise_compiles,
)
range_entry.compiled = True range_entry.compiled = True
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment