Unverified commit 1f214290, authored by DevByteAI and committed by GitHub
Browse files

fix(compile): apply partition wrapper when loading AOT cached functions (#31536)


Signed-off-by: Devbyteai <abud6673@gmail.com>
Signed-off-by: DevByteAI <161969603+devbyteai@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
parent 8cbdc7eb
...@@ -5,6 +5,7 @@ import functools ...@@ -5,6 +5,7 @@ import functools
import multiprocessing import multiprocessing
import tempfile import tempfile
from contextlib import contextmanager from contextlib import contextmanager
from pathlib import Path
import pytest import pytest
import torch import torch
...@@ -24,6 +25,13 @@ from vllm.utils.torch_utils import is_torch_equal_or_newer ...@@ -24,6 +25,13 @@ from vllm.utils.torch_utils import is_torch_equal_or_newer
from ..utils import create_new_process_for_each_test from ..utils import create_new_process_for_each_test
@pytest.fixture
def vllm_tmp_cache(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Path:
    """Point ``VLLM_CACHE_ROOT`` at a per-test temporary location.

    The environment variable is set, via ``monkeypatch``, to the
    ``vllm_cache`` entry under pytest's ``tmp_path``.

    Returns:
        ``tmp_path`` itself, i.e. the parent of the path stored in
        ``VLLM_CACHE_ROOT``, not the cache directory.
    """
    cache_root = tmp_path / "vllm_cache"
    monkeypatch.setenv("VLLM_CACHE_ROOT", str(cache_root))
    return tmp_path
def reference_fn(x: torch.Tensor): def reference_fn(x: torch.Tensor):
assert x.shape[0] <= 42 assert x.shape[0] <= 42
assert x.shape[0] % 2 == 0 assert x.shape[0] % 2 == 0
...@@ -148,6 +156,93 @@ def test_shape_env(monkeypatch: pytest.MonkeyPatch): ...@@ -148,6 +156,93 @@ def test_shape_env(monkeypatch: pytest.MonkeyPatch):
assert guards_string == " - s77 <= 42\n - Eq(Mod(s77, 2), 0)" assert guards_string == " - s77 <= 42\n - Eq(Mod(s77, 2), 0)"
@pytest.mark.skipif(
    not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
)
def test_partition_wrapper_applied_on_aot_load(
    monkeypatch: pytest.MonkeyPatch, vllm_tmp_cache: Path, mocker
):
    """
    Check that the partition wrapper is installed when running AOT-cached code.

    Regression test for GitHub issue #31439: with
    use_inductor_graph_partition=True, AOT compile caused a 2x latency
    regression because the partition wrapper context was bypassed when the
    function was loaded from the AOT cache.

    The test compiles once to populate the cache, then builds a fresh module
    with VLLM_FORCE_AOT_LOAD=1 and spies on
    torch._inductor.utils.set_customized_partition_wrappers. Both the first
    call (which loads from the cache) and a later call (which reuses the
    already-loaded function) must set a wrapper and then clear it.
    """
    from vllm.config import CUDAGraphMode

    def build_partition_config() -> VllmConfig:
        # Both phases need an identical config with graph partition enabled.
        return VllmConfig(
            compilation_config=CompilationConfig(
                mode=CompilationMode.VLLM_COMPILE,
                use_inductor_graph_partition=True,
                cudagraph_mode=CUDAGraphMode.PIECEWISE,
            )
        )

    def expect_set_then_cleared(wrapper_spy, expectation: str, phase: str) -> None:
        # A wrapper must be installed first (non-None argument) and removed
        # last (None argument); at least two spy calls are therefore needed.
        n_calls = wrapper_spy.call_count
        assert n_calls >= 2, f"{expectation}, got {n_calls} calls"
        recorded = wrapper_spy.call_args_list
        assert recorded[0].args[0] is not None, (
            f"First call on {phase} should set a wrapper function"
        )
        assert recorded[-1].args[0] is None, (
            f"Last call on {phase} should clear the wrapper"
        )

    args = (torch.randn(10, 10),)
    monkeypatch.setenv("VLLM_USE_AOT_COMPILE", "1")

    # Phase 1: compile with partition enabled so the AOT cache is populated.
    vllm_config = build_partition_config()
    with use_vllm_config(vllm_config):
        compiled_mod = CompiledMod(vllm_config=vllm_config)
        compiled_mod(*args)
    disable_envs_cache()

    # Phase 2: force loading from the cache and watch the wrapper setter.
    monkeypatch.setenv("VLLM_FORCE_AOT_LOAD", "1")
    vllm_config = build_partition_config()
    spy = mocker.spy(torch._inductor.utils, "set_customized_partition_wrappers")

    with use_vllm_config(vllm_config):
        compiled_mod = CompiledMod(vllm_config=vllm_config)

        # First call on the fresh module: loads from the AOT cache.
        compiled_mod(*args)
        expect_set_then_cleared(
            spy,
            "Expected partition wrapper to be set and cleared on AOT load",
            "AOT load",
        )

        # Forget the recorded calls before the second check.
        spy.reset_mock()

        # Later call: goes through the already-loaded `aot_compiled_fn`.
        compiled_mod(*args)
        expect_set_then_cleared(
            spy,
            "Expected partition wrapper set and cleared on subsequent call",
            "subsequent call",
        )
@pytest.mark.skipif( @pytest.mark.skipif(
not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10" not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
) )
......
...@@ -371,8 +371,11 @@ def _support_torch_compile( ...@@ -371,8 +371,11 @@ def _support_torch_compile(
if self.do_not_compile or torch.compiler.is_compiling(): if self.do_not_compile or torch.compiler.is_compiling():
return self.forward(*args, **kwargs) return self.forward(*args, **kwargs)
# if aot_compiled_fn is set, just call it. # if aot_compiled_fn is set, call it with partition wrapper context.
# The partition wrapper must be active at runtime for CUDA graph
# capture to work correctly with inductor graph partitioning.
if getattr(self, "aot_compiled_fn", None) is not None: if getattr(self, "aot_compiled_fn", None) is not None:
with maybe_use_cudagraph_partition_wrapper(self.vllm_config):
return self.aot_compiled_fn(self, *args, **kwargs) return self.aot_compiled_fn(self, *args, **kwargs)
ds_type = self.compilation_config.dynamic_shapes_config.type ds_type = self.compilation_config.dynamic_shapes_config.type
...@@ -432,6 +435,8 @@ def _support_torch_compile( ...@@ -432,6 +435,8 @@ def _support_torch_compile(
logger.info( logger.info(
"Directly load AOT compilation from path %s", aot_compilation_path "Directly load AOT compilation from path %s", aot_compilation_path
) )
# Apply partition wrapper context for proper CUDA graph capture
with maybe_use_cudagraph_partition_wrapper(self.vllm_config):
return self.aot_compiled_fn(self, *args, **kwargs) return self.aot_compiled_fn(self, *args, **kwargs)
if self.compiled: if self.compiled:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment