Unverified Commit 7c5dedc2 authored by dolpm's avatar dolpm Committed by GitHub
Browse files

[AOT compilation] support torch.compile inductor artifacts in VllmCompiledFunction (#25205)


Signed-off-by: default avatardolpm <34420038+dolpm@users.noreply.github.com>
parent 193069d1
...@@ -2,15 +2,21 @@ ...@@ -2,15 +2,21 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools import functools
import hashlib
import multiprocessing import multiprocessing
import pickle
import tempfile import tempfile
from contextlib import contextmanager from contextlib import contextmanager
from pathlib import Path from pathlib import Path
from unittest.mock import Mock, patch
import pytest import pytest
import torch import torch
import vllm.model_executor.layers.activation import vllm.model_executor.layers.activation
from vllm.compilation.caching import (
StandaloneCompiledArtifacts,
)
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import ( from vllm.config import (
CompilationConfig, CompilationConfig,
...@@ -40,6 +46,15 @@ def reference_fn(x: torch.Tensor): ...@@ -40,6 +46,15 @@ def reference_fn(x: torch.Tensor):
return x return x
def reference_fn_tuple(x: torch.Tensor):
"""Reference function that returns a tuple of tensors."""
assert x.shape[0] <= 42
assert x.shape[0] % 2 == 0
for _ in range(3000):
x = x + x.shape[0]
return x, x * 2
@support_torch_compile @support_torch_compile
class CompiledMod(torch.nn.Module): class CompiledMod(torch.nn.Module):
def __init__(self, **kwargs): def __init__(self, **kwargs):
...@@ -49,10 +64,22 @@ class CompiledMod(torch.nn.Module): ...@@ -49,10 +64,22 @@ class CompiledMod(torch.nn.Module):
return reference_fn(x) return reference_fn(x)
@support_torch_compile
class CompiledModTuple(torch.nn.Module):
"""A compiled module that returns a tuple of tensors."""
def __init__(self, **kwargs):
super().__init__()
def forward(self, x: torch.Tensor):
return reference_fn_tuple(x)
def make_vllm_config() -> VllmConfig: def make_vllm_config() -> VllmConfig:
return VllmConfig( return VllmConfig(
compilation_config=CompilationConfig( compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE, mode=CompilationMode.VLLM_COMPILE,
backend="inductor",
) )
) )
...@@ -73,6 +100,8 @@ def test_no_dynamo_cache_entry(monkeypatch: pytest.MonkeyPatch): ...@@ -73,6 +100,8 @@ def test_no_dynamo_cache_entry(monkeypatch: pytest.MonkeyPatch):
expected = reference_fn(*args) expected = reference_fn(*args)
with use_vllm_config(vllm_config): with use_vllm_config(vllm_config):
m.setenv("VLLM_USE_AOT_COMPILE", "0") m.setenv("VLLM_USE_AOT_COMPILE", "0")
m.setenv("VLLM_USE_MEGA_AOT_ARTIFACT", "1")
m.setenv("VLLM_USE_STANDALONE_COMPILE", "1")
with ( with (
pytest.raises(RuntimeError, match="Detected recompile"), pytest.raises(RuntimeError, match="Detected recompile"),
torch.compiler.set_stance("fail_on_recompile"), torch.compiler.set_stance("fail_on_recompile"),
...@@ -94,6 +123,8 @@ def test_force_aot_load(monkeypatch: pytest.MonkeyPatch): ...@@ -94,6 +123,8 @@ def test_force_aot_load(monkeypatch: pytest.MonkeyPatch):
with tempfile.TemporaryDirectory() as tmpdirname, monkeypatch.context() as m: with tempfile.TemporaryDirectory() as tmpdirname, monkeypatch.context() as m:
args = (torch.randn(10, 10),) args = (torch.randn(10, 10),)
m.setenv("VLLM_USE_AOT_COMPILE", "1") m.setenv("VLLM_USE_AOT_COMPILE", "1")
m.setenv("VLLM_USE_MEGA_AOT_ARTIFACT", "1")
m.setenv("VLLM_USE_STANDALONE_COMPILE", "1")
m.setenv("VLLM_FORCE_AOT_LOAD", "1") m.setenv("VLLM_FORCE_AOT_LOAD", "1")
m.setenv("VLLM_CACHE_ROOT", tmpdirname) m.setenv("VLLM_CACHE_ROOT", tmpdirname)
vllm_config = make_vllm_config() vllm_config = make_vllm_config()
...@@ -111,18 +142,158 @@ def test_save_and_load(monkeypatch: pytest.MonkeyPatch): ...@@ -111,18 +142,158 @@ def test_save_and_load(monkeypatch: pytest.MonkeyPatch):
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
m.setenv("VLLM_CACHE_ROOT", tmpdirname) m.setenv("VLLM_CACHE_ROOT", tmpdirname)
m.setenv("VLLM_USE_AOT_COMPILE", "1") m.setenv("VLLM_USE_AOT_COMPILE", "1")
m.setenv("VLLM_USE_MEGA_AOT_ARTIFACT", "1")
m.setenv("VLLM_USE_STANDALONE_COMPILE", "1")
vllm_config = make_vllm_config() vllm_config = make_vllm_config()
with use_vllm_config(vllm_config): with use_vllm_config(vllm_config):
expected = CompiledMod(vllm_config=vllm_config)(*args) compiled_mod = CompiledMod(vllm_config=vllm_config)
expected = compiled_mod(*args)
disable_envs_cache() disable_envs_cache()
m.setenv("VLLM_FORCE_AOT_LOAD", "1") m.setenv("VLLM_FORCE_AOT_LOAD", "1")
vllm_config = make_vllm_config() vllm_config = make_vllm_config()
with use_vllm_config(vllm_config): with use_vllm_config(vllm_config):
ret = CompiledMod(vllm_config=vllm_config)(*args) cached_mod = CompiledMod(vllm_config=vllm_config)
ret = cached_mod(*args)
assert cached_mod.was_aot_compile_fn_loaded_from_disk, (
"Expected was_aot_compile_fn_loaded_from_disk to be True"
)
assert torch.allclose(ret, expected) assert torch.allclose(ret, expected)
@pytest.mark.skipif(
not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
)
def test_cache_load_returns_tuple_consistency(monkeypatch: pytest.MonkeyPatch):
"""
Test that cache loading correctly handles the returns_tuple logic.
This verifies that when a model returns a single tensor (not a tuple),
the output type is consistent between fresh compilation and cache load.
Without the fix, cached artifacts would return [tensor] instead of tensor.
"""
with monkeypatch.context() as m:
args = (torch.randn(10, 10),)
with tempfile.TemporaryDirectory() as tmpdirname:
m.setenv("VLLM_CACHE_ROOT", tmpdirname)
m.setenv("VLLM_USE_AOT_COMPILE", "1")
m.setenv("VLLM_USE_MEGA_AOT_ARTIFACT", "1")
m.setenv("VLLM_USE_STANDALONE_COMPILE", "1")
vllm_config = make_vllm_config()
# Fresh compilation
with use_vllm_config(vllm_config):
compiled_mod = CompiledMod(vllm_config=vllm_config)
fresh_result = compiled_mod(*args)
fresh_result_type = type(fresh_result)
# Verify fresh result is a tensor, not a tuple/list
assert isinstance(fresh_result, torch.Tensor), (
f"Fresh compile should return tensor, got {fresh_result_type}"
)
disable_envs_cache()
# Load from cache
m.setenv("VLLM_FORCE_AOT_LOAD", "1")
vllm_config = make_vllm_config()
with use_vllm_config(vllm_config):
cached_mod = CompiledMod(vllm_config=vllm_config)
cached_result = cached_mod(*args)
cached_result_type = type(cached_result)
# Verify cache was actually loaded
assert cached_mod.was_aot_compile_fn_loaded_from_disk, (
"Expected was_aot_compile_fn_loaded_from_disk to be True after "
"loading from cache"
)
# Verify cached result has same type as fresh result
assert isinstance(cached_result, torch.Tensor), (
f"Cache load should return tensor, got {cached_result_type}. "
"This indicates the returns_tuple logic is not being applied "
"correctly when loading from cache."
)
# Verify values match
assert torch.allclose(cached_result, fresh_result), (
"Cached result values should match fresh compilation"
)
@pytest.mark.skipif(
not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
)
def test_cache_load_returns_tuple_consistency_tuple_output(
monkeypatch: pytest.MonkeyPatch,
):
"""
Test that cache loading correctly handles models that return tuples.
This verifies that when a model returns a tuple of tensors, the output
type is preserved as a tuple between fresh compilation and cache load.
"""
with monkeypatch.context() as m:
args = (torch.randn(10, 10),)
with tempfile.TemporaryDirectory() as tmpdirname:
m.setenv("VLLM_CACHE_ROOT", tmpdirname)
m.setenv("VLLM_USE_AOT_COMPILE", "1")
m.setenv("VLLM_USE_MEGA_AOT_ARTIFACT", "1")
m.setenv("VLLM_USE_STANDALONE_COMPILE", "1")
vllm_config = make_vllm_config()
# Fresh compilation with tuple-returning model
with use_vllm_config(vllm_config):
compiled_mod = CompiledModTuple(vllm_config=vllm_config)
fresh_result = compiled_mod(*args)
fresh_result_type = type(fresh_result)
# Verify fresh result is a tuple
assert isinstance(fresh_result, tuple), (
f"Fresh compile should return tuple, got {fresh_result_type}"
)
assert len(fresh_result) == 2, (
f"Fresh compile should return 2-tuple, got {len(fresh_result)}"
)
disable_envs_cache()
# Load from cache
m.setenv("VLLM_FORCE_AOT_LOAD", "1")
vllm_config = make_vllm_config()
with use_vllm_config(vllm_config):
cached_mod = CompiledModTuple(vllm_config=vllm_config)
cached_result = cached_mod(*args)
cached_result_type = type(cached_result)
# Verify cache was actually loaded
assert cached_mod.was_aot_compile_fn_loaded_from_disk, (
"Expected was_aot_compile_fn_loaded_from_disk to be True after "
"loading from cache"
)
# Verify cached result is also a tuple
assert isinstance(cached_result, tuple), (
f"Cache load should return tuple, got {cached_result_type}. "
"This indicates the returns_tuple logic is not preserving "
"tuple outputs when loading from cache."
)
assert len(cached_result) == 2, (
f"Cache load should return 2-tuple, got {len(cached_result)}"
)
# Verify values match
assert torch.allclose(cached_result[0], fresh_result[0]), (
"Cached result[0] values should match fresh compilation"
)
assert torch.allclose(cached_result[1], fresh_result[1]), (
"Cached result[1] values should match fresh compilation"
)
@pytest.mark.skipif( @pytest.mark.skipif(
not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10" not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
) )
...@@ -137,6 +308,8 @@ def test_shape_env(monkeypatch: pytest.MonkeyPatch): ...@@ -137,6 +308,8 @@ def test_shape_env(monkeypatch: pytest.MonkeyPatch):
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
m.setenv("VLLM_CACHE_ROOT", tmpdirname) m.setenv("VLLM_CACHE_ROOT", tmpdirname)
m.setenv("VLLM_USE_AOT_COMPILE", "1") m.setenv("VLLM_USE_AOT_COMPILE", "1")
m.setenv("VLLM_USE_MEGA_AOT_ARTIFACT", "1")
m.setenv("VLLM_USE_STANDALONE_COMPILE", "1")
vllm_config = make_vllm_config() vllm_config = make_vllm_config()
with use_vllm_config(vllm_config): with use_vllm_config(vllm_config):
compiled_mod = CompiledMod(vllm_config=vllm_config) compiled_mod = CompiledMod(vllm_config=vllm_config)
...@@ -144,6 +317,7 @@ def test_shape_env(monkeypatch: pytest.MonkeyPatch): ...@@ -144,6 +317,7 @@ def test_shape_env(monkeypatch: pytest.MonkeyPatch):
artifacts = compiled_mod.aot_compiled_fn._artifacts artifacts = compiled_mod.aot_compiled_fn._artifacts
guards_string = artifacts.compiled_fn.shape_env.format_guards() guards_string = artifacts.compiled_fn.shape_env.format_guards()
assert guards_string == " - s77 <= 42\n - Eq(Mod(s77, 2), 0)" assert guards_string == " - s77 <= 42\n - Eq(Mod(s77, 2), 0)"
disable_envs_cache() disable_envs_cache()
m.setenv("VLLM_FORCE_AOT_LOAD", "1") m.setenv("VLLM_FORCE_AOT_LOAD", "1")
...@@ -151,6 +325,9 @@ def test_shape_env(monkeypatch: pytest.MonkeyPatch): ...@@ -151,6 +325,9 @@ def test_shape_env(monkeypatch: pytest.MonkeyPatch):
with use_vllm_config(vllm_config): with use_vllm_config(vllm_config):
compiled_mod = CompiledMod(vllm_config=vllm_config) compiled_mod = CompiledMod(vllm_config=vllm_config)
compiled_mod(*args) compiled_mod(*args)
assert compiled_mod.was_aot_compile_fn_loaded_from_disk, (
"Expected was_aot_compile_fn_loaded_from_disk to be True"
)
artifacts = compiled_mod.aot_compiled_fn._artifacts artifacts = compiled_mod.aot_compiled_fn._artifacts
guards_string = artifacts.compiled_fn.shape_env.format_guards() guards_string = artifacts.compiled_fn.shape_env.format_guards()
assert guards_string == " - s77 <= 42\n - Eq(Mod(s77, 2), 0)" assert guards_string == " - s77 <= 42\n - Eq(Mod(s77, 2), 0)"
...@@ -188,6 +365,7 @@ def test_partition_wrapper_applied_on_aot_load( ...@@ -188,6 +365,7 @@ def test_partition_wrapper_applied_on_aot_load(
with use_vllm_config(vllm_config): with use_vllm_config(vllm_config):
compiled_mod = CompiledMod(vllm_config=vllm_config) compiled_mod = CompiledMod(vllm_config=vllm_config)
compiled_mod(*args) compiled_mod(*args)
disable_envs_cache() disable_envs_cache()
# Second run - load from cache, verify partition wrapper applied # Second run - load from cache, verify partition wrapper applied
...@@ -210,6 +388,11 @@ def test_partition_wrapper_applied_on_aot_load( ...@@ -210,6 +388,11 @@ def test_partition_wrapper_applied_on_aot_load(
# This tests the fix for the first call after a restart. # This tests the fix for the first call after a restart.
compiled_mod(*args) compiled_mod(*args)
# Verify cache was loaded
assert compiled_mod.was_aot_compile_fn_loaded_from_disk, (
"Expected was_aot_compile_fn_loaded_from_disk to be True"
)
# Verify partition wrapper was called on AOT load. # Verify partition wrapper was called on AOT load.
assert spy.call_count >= 2, ( assert spy.call_count >= 2, (
"Expected partition wrapper to be set and cleared on AOT load, " "Expected partition wrapper to be set and cleared on AOT load, "
...@@ -307,3 +490,233 @@ def test_gpt2_cache_hit(monkeypatch: pytest.MonkeyPatch): ...@@ -307,3 +490,233 @@ def test_gpt2_cache_hit(monkeypatch: pytest.MonkeyPatch):
finally: finally:
# Restore original method # Restore original method
symbolic_shapes_module.make_symbol = original_make_symbol symbolic_shapes_module.make_symbol = original_make_symbol
@pytest.mark.skipif(
not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
)
class TestStandaloneCompiledArtifacts:
def test_init(self):
cache = StandaloneCompiledArtifacts()
assert cache.submodule_bytes == {}
assert cache.submodule_bytes_store == {}
assert cache.loaded_submodule_store == {}
def test_insert_new_artifact(self):
cache = StandaloneCompiledArtifacts()
test_data = b"test_artifact_data"
submod_name = "test_submod"
shape = "s1"
hasher = hashlib.sha256()
hasher.update(test_data)
expected_hash = hasher.hexdigest()
cache.insert(submod_name, shape, test_data)
assert f"{submod_name}_{shape}" in cache.submodule_bytes
assert cache.submodule_bytes[f"{submod_name}_{shape}"] == expected_hash
assert expected_hash in cache.submodule_bytes_store
assert cache.submodule_bytes_store[expected_hash] == test_data
def test_insert_duplicate_artifact(self):
cache = StandaloneCompiledArtifacts()
test_data = b"duplicate_test_data"
submod_name1 = "submod1"
submod_name2 = "submod2"
shape = "s2"
cache.insert(submod_name1, shape, test_data)
cache.insert(submod_name2, shape, test_data)
hash1 = cache.submodule_bytes[f"{submod_name1}_{shape}"]
hash2 = cache.submodule_bytes[f"{submod_name2}_{shape}"]
assert hash1 == hash2
assert len(cache.submodule_bytes_store) == 1
assert len(cache.submodule_bytes) == 2
def test_get_artifact(self):
cache = StandaloneCompiledArtifacts()
test_data = b"retrievable_data"
submod_name = "mod1"
shape = "shape16"
cache.insert(submod_name, shape, test_data)
retrieved_data = cache.get(submod_name, shape)
assert retrieved_data == test_data
def test_get_nonexistent_artifact(self):
cache = StandaloneCompiledArtifacts()
with pytest.raises(KeyError):
cache.get("nonexistent", "shape")
def test_size_bytes(self):
cache = StandaloneCompiledArtifacts()
assert cache.size_bytes() == 0
data1 = b"x" * 100
data2 = b"y" * 200
cache.insert("mod1", "shape1", data1)
cache.insert("mod2", "shape2", data2)
assert cache.size_bytes() == 300
def test_num_artifacts_and_entries(self):
cache = StandaloneCompiledArtifacts()
assert cache.num_artifacts() == 0
assert cache.num_entries() == 0
cache.insert("mod1", "shape1", b"data1")
cache.insert("mod2", "shape2", b"data2")
assert cache.num_artifacts() == 2
assert cache.num_entries() == 2
cache.insert("mod3", "shape3", b"data1")
assert cache.num_artifacts() == 2
assert cache.num_entries() == 3
@patch("torch._inductor.standalone_compile.AOTCompiledArtifact.deserialize")
def test_load_all_success(self, mock_deserialize):
"""Test successful loading of all artifacts"""
cache = StandaloneCompiledArtifacts()
mock_artifact1 = Mock()
mock_artifact2 = Mock()
mock_deserialize.side_effect = [mock_artifact1, mock_artifact2]
cache.insert("mod1", "shape1", pickle.dumps(b"data1"))
cache.insert("mod2", "shape2", pickle.dumps(b"data2"))
cache.load_all()
assert len(cache.loaded_submodule_store) == 2
assert mock_deserialize.call_count == 2
@patch("torch._inductor.standalone_compile.AOTCompiledArtifact.deserialize")
def test_load_all_already_loaded(self, mock_deserialize):
"""Test that load_all skips if already loaded"""
cache = StandaloneCompiledArtifacts()
mock_artifact = Mock()
cache.submodule_bytes_store["hash1"] = pickle.dumps(b"data1")
cache.loaded_submodule_store["hash1"] = mock_artifact
cache.load_all()
mock_deserialize.assert_not_called()
@patch("torch._inductor.standalone_compile.AOTCompiledArtifact.deserialize")
def test_get_loaded_artifact(self, mock_deserialize):
"""Test retrieving loaded artifacts"""
cache = StandaloneCompiledArtifacts()
mock_artifact = Mock()
mock_deserialize.return_value = mock_artifact
submod_name = "test_mod"
shape = "test_shape"
cache.insert(submod_name, shape, pickle.dumps(b"test_data"))
cache.load_all()
retrieved_artifact = cache.get_loaded(submod_name, shape)
assert retrieved_artifact == mock_artifact
def test_getstate_setstate(self):
cache = StandaloneCompiledArtifacts()
cache.insert("mod1", "shape1", b"data1")
cache.insert("mod2", "shape2", b"data2")
cache.loaded_submodule_store["hash1"] = Mock()
state = cache.__getstate__()
assert "submodule_bytes" in state
assert "submodule_bytes_store" in state
assert "loaded_submodule_store" not in state
new_cache = StandaloneCompiledArtifacts()
new_cache.__setstate__(state)
assert new_cache.submodule_bytes == cache.submodule_bytes
assert new_cache.submodule_bytes_store == cache.submodule_bytes_store
assert new_cache.loaded_submodule_store == {}
def test_pickle_roundtrip(self):
cache = StandaloneCompiledArtifacts()
test_data1 = b"pickle_test_data_1"
test_data2 = b"pickle_test_data_2"
cache.insert("mod1", "shape1", test_data1)
cache.insert("mod2", "shape2", test_data2)
pickled_data = pickle.dumps(cache)
restored_cache = pickle.loads(pickled_data)
assert restored_cache.get("mod1", "shape1") == test_data1
assert restored_cache.get("mod2", "shape2") == test_data2
assert restored_cache.num_artifacts() == cache.num_artifacts()
assert restored_cache.num_entries() == cache.num_entries()
assert restored_cache.size_bytes() == cache.size_bytes()
assert len(restored_cache.loaded_submodule_store) == 0
@pytest.mark.skipif(
not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
)
class TestStandaloneCompiledArtifactsIntegration:
def test_add_pickle_unpickle(self):
cache = StandaloneCompiledArtifacts()
artifacts = {
("mod1", "shape1"): b"m1s1_artifact",
("mod1", "shape2"): b"m1s2_artifact",
("mod2", "shape1"): b"m2s1_artifact",
("mod2", "shape2"): b"m2s2_artifact",
}
for (submod, shape), data in artifacts.items():
cache.insert(submod, shape, data)
assert cache.num_entries() == 4
assert cache.num_artifacts() == 4
for (submod, shape), expected_data in artifacts.items():
retrieved_data = cache.get(submod, shape)
assert retrieved_data == expected_data
pickled = pickle.dumps(cache)
restored_cache = pickle.loads(pickled)
for (submod, shape), expected_data in artifacts.items():
retrieved_data = restored_cache.get(submod, shape)
assert retrieved_data == expected_data
def test_deduplication(self):
cache = StandaloneCompiledArtifacts()
shared_data = b"shared_artifact_data" * 1000
cache.insert("mod1", "shape1", shared_data)
cache.insert("mod2", "shape1", shared_data)
cache.insert("mod1", "shape2", shared_data)
cache.insert("mod3", "shape3", shared_data)
assert cache.num_entries() == 4
assert cache.num_artifacts() == 1
assert cache.size_bytes() == len(shared_data)
for submod, shape in [
("mod1", "shape1"),
("mod2", "shape1"),
("mod1", "shape2"),
("mod3", "shape3"),
]:
assert cache.get(submod, shape) == shared_data
...@@ -21,6 +21,7 @@ ALLOWED_FILES = { ...@@ -21,6 +21,7 @@ ALLOWED_FILES = {
"vllm/transformers_utils/config.py", "vllm/transformers_utils/config.py",
"vllm/model_executor/models/registry.py", "vllm/model_executor/models/registry.py",
"vllm/compilation/caching.py", "vllm/compilation/caching.py",
"vllm/compilation/piecewise_backend.py",
"vllm/distributed/utils.py", "vllm/distributed/utils.py",
"vllm/distributed/parallel_state.py", "vllm/distributed/parallel_state.py",
"vllm/distributed/device_communicators/all_reduce_utils.py", "vllm/distributed/device_communicators/all_reduce_utils.py",
...@@ -30,6 +31,7 @@ ALLOWED_FILES = { ...@@ -30,6 +31,7 @@ ALLOWED_FILES = {
"tests/multimodal/media/test_base.py", "tests/multimodal/media/test_base.py",
"tests/tokenizers_/test_hf.py", "tests/tokenizers_/test_hf.py",
"tests/utils_/test_hashing.py", "tests/utils_/test_hashing.py",
"tests/compile/test_aot_compile.py",
"benchmarks/kernels/graph_machete_bench.py", "benchmarks/kernels/graph_machete_bench.py",
"benchmarks/kernels/benchmark_lora.py", "benchmarks/kernels/benchmark_lora.py",
"benchmarks/kernels/benchmark_machete.py", "benchmarks/kernels/benchmark_machete.py",
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast import ast
import contextvars
import dataclasses import dataclasses
import hashlib import hashlib
import json import json
...@@ -34,7 +35,6 @@ from vllm.platforms import current_platform ...@@ -34,7 +35,6 @@ from vllm.platforms import current_platform
from vllm.utils.import_utils import resolve_obj_by_qualname from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.utils.torch_utils import is_torch_equal_or_newer from vllm.utils.torch_utils import is_torch_equal_or_newer
from .caching import VllmSerializableFunction
from .compiler_interface import ( from .compiler_interface import (
CompilerInterface, CompilerInterface,
EagerAdaptor, EagerAdaptor,
...@@ -49,7 +49,48 @@ from .pass_manager import PostGradPassManager ...@@ -49,7 +49,48 @@ from .pass_manager import PostGradPassManager
logger = init_logger(__name__) logger = init_logger(__name__)
def make_copy_and_call(
sym_tensor_indices: list[int],
input_buffers: list[torch.Tensor | None],
callable_fn: Callable[..., Any],
) -> Callable[..., Any]:
"""Create a wrapper that copies inputs to static buffers before calling.
This is used for cudagraph input copying where we need to copy dynamic
tensors to static buffers before invoking the compiled graph.
Args:
sym_tensor_indices: Indices of tensors with symbolic shapes
input_buffers: List of static buffers (can contain None for lazy init)
callable_fn: The compiled function to call
Returns:
A wrapper function that copies inputs and calls the compiled function
"""
def copy_and_call(*args):
list_args = list(args)
for i, index in enumerate(sym_tensor_indices):
runtime_tensor = list_args[index]
runtime_shape = runtime_tensor.shape[0]
# lazy initialization of buffer on first call
if input_buffers[i] is None:
input_buffers[i] = runtime_tensor.clone()
static_tensor = input_buffers[i][:runtime_shape] # type: ignore[index]
static_tensor.copy_(runtime_tensor)
list_args[index] = static_tensor
return callable_fn(*list_args)
return copy_and_call
def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface: def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
assert not envs.VLLM_USE_MEGA_AOT_ARTIFACT or envs.VLLM_USE_STANDALONE_COMPILE, (
"VLLM_USE_MEGA_AOT_ARTIFACT=1 requires VLLM_USE_STANDALONE_COMPILE=1"
)
if compilation_config.backend == "inductor": if compilation_config.backend == "inductor":
# Use standalone compile only if requested, version is new enough, # Use standalone compile only if requested, version is new enough,
# and the symbol actually exists in this PyTorch build. # and the symbol actually exists in this PyTorch build.
...@@ -355,6 +396,60 @@ def split_graph( ...@@ -355,6 +396,60 @@ def split_graph(
compilation_start_time = 0.0 compilation_start_time = 0.0
def wrap_with_cudagraph_if_needed(
piecewise_backend: Any,
vllm_config: VllmConfig,
compilation_config: CompilationConfig,
is_first_graph: bool,
is_last_graph: bool,
) -> Any:
"""
Wrap a piecewise backend with CUDA graph wrapper if needed.
This function is shared between VllmBackend and
construct_serializable_fn_from_inductor_cache.
Args:
piecewise_backend: The backend to wrap
vllm_config: The vLLM configuration
compilation_config: The compilation configuration
is_first_graph: Whether this is the first graph in the sequence
is_last_graph: Whether this is the last graph in the sequence
Returns:
The wrapped backend if CUDA graphs are enabled, otherwise the original backend
"""
if (
not compilation_config.cudagraph_mode.has_piecewise_cudagraphs()
or compilation_config.use_inductor_graph_partition
):
return piecewise_backend
# We're using Dynamo-based piecewise splitting, so we wrap
# the whole subgraph with a static graph wrapper.
from .cuda_graph import CUDAGraphOptions
# resolve the static graph wrapper class (e.g. CUDAGraphWrapper
# class) as platform dependent.
static_graph_wrapper_class = resolve_obj_by_qualname(
current_platform.get_static_graph_wrapper_cls()
)
# Always assign PIECEWISE runtime mode to the
# CUDAGraphWrapper for piecewise_backend, to distinguish
# it from the FULL cudagraph runtime mode, no matter it
# is wrapped on a full or piecewise fx graph.
return static_graph_wrapper_class(
runnable=piecewise_backend,
vllm_config=vllm_config,
runtime_mode=CUDAGraphMode.PIECEWISE,
cudagraph_options=CUDAGraphOptions(
debug_log_enable=is_first_graph,
gc_disable=not is_first_graph,
weak_ref_output=is_last_graph,
),
)
class PiecewiseCompileInterpreter(torch.fx.Interpreter): # type: ignore[misc] class PiecewiseCompileInterpreter(torch.fx.Interpreter): # type: ignore[misc]
"""Code adapted from `torch.fx.passes.shape_prop.ShapeProp`. """Code adapted from `torch.fx.passes.shape_prop.ShapeProp`.
It runs the given graph with fake inputs, and compile some It runs the given graph with fake inputs, and compile some
...@@ -365,6 +460,18 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): # type: ignore[misc] ...@@ -365,6 +460,18 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): # type: ignore[misc]
it will be used to determine the order of the compiled piecewise it will be used to determine the order of the compiled piecewise
graphs. The first graph will handle logging, and the last graph graphs. The first graph will handle logging, and the last graph
has some special cudagraph output handling. has some special cudagraph output handling.
Note: This class shares similar logic with
reconstruct_serializable_fn_from_mega_artifact in caching.py.
Both create PiecewiseBackend instances and wrap them with cudagraph.
The key difference is:
- reconstruct_serializable_fn_from_mega_artifact: PiecewiseBackend receives
pre-compiled runnables (compiled_runnables is set, graph is None)
- this class: PiecewiseBackend receives the FX graph to compile
(graph is set, compiled_runnables is None)
If modifying the backend creation/wrapping logic, consider updating both.
""" """
def __init__( def __init__(
...@@ -413,6 +520,8 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): # type: ignore[misc] ...@@ -413,6 +520,8 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): # type: ignore[misc]
] ]
# Lazy import here to avoid circular import # Lazy import here to avoid circular import
from torch._inductor.compile_fx import graph_returns_tuple
from .piecewise_backend import PiecewiseBackend from .piecewise_backend import PiecewiseBackend
piecewise_backend = PiecewiseBackend( piecewise_backend = PiecewiseBackend(
...@@ -422,38 +531,16 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): # type: ignore[misc] ...@@ -422,38 +531,16 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): # type: ignore[misc]
len(self.compile_submod_names), len(self.compile_submod_names),
sym_shape_indices, sym_shape_indices,
self.vllm_backend, self.vllm_backend,
graph_returns_tuple(submod),
) )
if ( self.module.__dict__[target] = wrap_with_cudagraph_if_needed(
self.compilation_config.cudagraph_mode.has_piecewise_cudagraphs() piecewise_backend,
and not self.compilation_config.use_inductor_graph_partition self.vllm_config,
): self.compilation_config,
# We're using Dynamo-based piecewise splitting, so we wrap piecewise_backend.is_first_graph,
# the whole subgraph with a static graph wrapper. piecewise_backend.is_last_graph,
from .cuda_graph import CUDAGraphOptions
# resolve the static graph wrapper class (e.g. CUDAGraphWrapper
# class) as platform dependent.
static_graph_wrapper_class = resolve_obj_by_qualname(
current_platform.get_static_graph_wrapper_cls()
)
# Always assign PIECEWISE runtime mode to the
# CUDAGraphWrapper for piecewise_backend, to distinguish
# it from the FULL cudagraph runtime mode, no matter it
# is wrapped on a full or piecewise fx graph.
self.module.__dict__[target] = static_graph_wrapper_class(
runnable=piecewise_backend,
vllm_config=self.vllm_config,
runtime_mode=CUDAGraphMode.PIECEWISE,
cudagraph_options=CUDAGraphOptions(
debug_log_enable=piecewise_backend.is_first_graph,
gc_disable=not piecewise_backend.is_first_graph,
weak_ref_output=piecewise_backend.is_last_graph,
),
) )
else:
self.module.__dict__[target] = piecewise_backend
compilation_counter.num_piecewise_capturable_graphs_seen += 1 compilation_counter.num_piecewise_capturable_graphs_seen += 1
...@@ -465,6 +552,21 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): # type: ignore[misc] ...@@ -465,6 +552,21 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): # type: ignore[misc]
model_tag: str = "backbone" model_tag: str = "backbone"
model_is_encoder: bool = False model_is_encoder: bool = False
_on_compilation_complete_callback: contextvars.ContextVar[Callable[[], None] | None] = (
contextvars.ContextVar("on_compilation_complete_callback", default=None)
)
@contextmanager
def set_on_compilation_complete(
callback: Callable[[], None],
) -> Generator[None, None, None]:
token = _on_compilation_complete_callback.set(callback)
try:
yield
finally:
_on_compilation_complete_callback.reset(token)
@contextmanager @contextmanager
def set_model_tag(tag: str, is_encoder: bool = False) -> Generator[None, None, None]: def set_model_tag(tag: str, is_encoder: bool = False) -> Generator[None, None, None]:
...@@ -509,8 +611,6 @@ class VllmBackend: ...@@ -509,8 +611,6 @@ class VllmBackend:
returned_callable: Callable[..., Any] returned_callable: Callable[..., Any]
# Inductor passes to run on the graph pre-defunctionalization # Inductor passes to run on the graph pre-defunctionalization
post_grad_passes: Sequence[Callable[..., Any]] post_grad_passes: Sequence[Callable[..., Any]]
sym_tensor_indices: list[int]
input_buffers: list[torch.Tensor]
compiler_manager: CompilerManager compiler_manager: CompilerManager
# Copy of CompilationConfig.inductor_compile_config + # Copy of CompilationConfig.inductor_compile_config +
# an entry for PostGradPassManager # an entry for PostGradPassManager
...@@ -539,9 +639,6 @@ class VllmBackend: ...@@ -539,9 +639,6 @@ class VllmBackend:
)() )()
self.pass_key = current_platform.pass_key self.pass_key = current_platform.pass_key
self.sym_tensor_indices = []
self.input_buffers = []
self.vllm_config = vllm_config self.vllm_config = vllm_config
self.compilation_config = vllm_config.compilation_config self.compilation_config = vllm_config.compilation_config
...@@ -558,6 +655,68 @@ class VllmBackend: ...@@ -558,6 +655,68 @@ class VllmBackend:
# `torch.compile` is JIT compiled, so we don't need to # `torch.compile` is JIT compiled, so we don't need to
# do anything here # do anything here
def collect_standalone_compile_artifacts(
self,
) -> tuple[Any, dict[str, list[int]] | None, dict[str, bool] | None]:
"""Collect inductor cache artifacts from all piecewise backends.
Returns:
tuple: (standalone_compile_artifacts, sym_shape_indices_map,
returns_tuple_map)
- standalone_compile_artifacts: StandaloneCompiledArtifacts
with compiled artifacts
- sym_shape_indices_map: dict mapping submod_name to
sym_shape_indices
- returns_tuple_map: dict mapping submod_name to
returns_tuple
"""
if not envs.VLLM_USE_MEGA_AOT_ARTIFACT:
return None, None, None
from .caching import StandaloneCompiledArtifacts
from .piecewise_backend import PiecewiseBackend
standalone_compile_artifacts = StandaloneCompiledArtifacts()
sym_shape_indices_map = {}
returns_tuple_map = {}
for name, _ in self.split_gm.named_children():
# get the actual attribute (shadowed by PiecewiseBackend in __dict__)
child = getattr(self.split_gm, name)
# unwrap the static graph wrapper class if applicable
piecewise_backend = child.runnable if hasattr(child, "runnable") else child
if not isinstance(piecewise_backend, PiecewiseBackend):
continue
submod_name = name
sym_shape_indices_map[submod_name] = piecewise_backend.sym_shape_indices
returns_tuple_map[submod_name] = piecewise_backend.returns_tuple
for shape_str, bytes_data in piecewise_backend.to_bytes().items():
standalone_compile_artifacts.insert(submod_name, shape_str, bytes_data)
logger.debug(
"collected artifact for %s shape %s (%d bytes)",
submod_name,
shape_str,
len(bytes_data),
)
logger.info(
"collected artifacts: %d entries, %d artifacts, %d bytes total",
standalone_compile_artifacts.num_entries(),
standalone_compile_artifacts.num_artifacts(),
standalone_compile_artifacts.size_bytes(),
)
logger.debug(
"standalone compile artifact keys: %s",
list(standalone_compile_artifacts.submodule_bytes.keys()),
)
return standalone_compile_artifacts, sym_shape_indices_map, returns_tuple_map
def configure_post_pass(self) -> None: def configure_post_pass(self) -> None:
self.pass_manager.configure(self.vllm_config) self.pass_manager.configure(self.vllm_config)
...@@ -579,9 +738,11 @@ class VllmBackend: ...@@ -579,9 +738,11 @@ class VllmBackend:
) )
self.inductor_config[self.pass_key] = self.pass_manager self.inductor_config[self.pass_key] = self.pass_manager
def __call__( def __call__(self, graph: fx.GraphModule, example_inputs: Sequence[Any]) -> Any:
self, graph: fx.GraphModule, example_inputs: Sequence[Any] from .caching import (
) -> VllmSerializableFunction: VllmSerializableFunction,
)
vllm_config = self.vllm_config vllm_config = self.vllm_config
# Minimal hashing here with existing utilities, reused below. # Minimal hashing here with existing utilities, reused below.
...@@ -721,6 +882,12 @@ class VllmBackend: ...@@ -721,6 +882,12 @@ class VllmBackend:
self.split_gm, self.piecewise_graphs = split_graph(graph, fx_split_ops) self.split_gm, self.piecewise_graphs = split_graph(graph, fx_split_ops)
# keep a split_gm copy from BEFORE the interpreter replaces
# submodules with PiecewiseBackend -- used for serialization
original_split_gm = None
if envs.VLLM_USE_MEGA_AOT_ARTIFACT:
original_split_gm = deepcopy(self.split_gm)
from torch._dynamo.utils import lazy_format_graph_code from torch._dynamo.utils import lazy_format_graph_code
# depyf will hook lazy_format_graph_code and dump the graph # depyf will hook lazy_format_graph_code and dump the graph
...@@ -792,13 +959,21 @@ class VllmBackend: ...@@ -792,13 +959,21 @@ class VllmBackend:
) )
self._called = True self._called = True
graph_to_serialize = (
original_split_gm if envs.VLLM_USE_MEGA_AOT_ARTIFACT else self.graph
)
if ( if (
self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
or not self.compilation_config.cudagraph_copy_inputs or not self.compilation_config.cudagraph_copy_inputs
): ):
return VllmSerializableFunction( return VllmSerializableFunction(
graph, example_inputs, self.prefix, self.split_gm, self.is_encoder graph_to_serialize,
example_inputs,
self.prefix,
self.split_gm,
is_encoder=self.is_encoder,
vllm_backend=self,
) )
# index of tensors that have symbolic shapes (batch size) # index of tensors that have symbolic shapes (batch size)
...@@ -806,7 +981,7 @@ class VllmBackend: ...@@ -806,7 +981,7 @@ class VllmBackend:
# symbolic shape only happens for input tensors. # symbolic shape only happens for input tensors.
from torch.fx.experimental.symbolic_shapes import is_symbolic from torch.fx.experimental.symbolic_shapes import is_symbolic
self.sym_tensor_indices = [ sym_tensor_indices = [
i i
for i, x in enumerate(fake_args) for i, x in enumerate(fake_args)
if isinstance(x, torch._subclasses.fake_tensor.FakeTensor) if isinstance(x, torch._subclasses.fake_tensor.FakeTensor)
...@@ -816,25 +991,18 @@ class VllmBackend: ...@@ -816,25 +991,18 @@ class VllmBackend:
# compiler managed cudagraph input buffers # compiler managed cudagraph input buffers
# we assume the first run with symbolic shapes # we assume the first run with symbolic shapes
# has the maximum size among all the tensors # has the maximum size among all the tensors
self.input_buffers = [ copy_and_call = make_copy_and_call(
example_inputs[x].clone() for x in self.sym_tensor_indices sym_tensor_indices,
] [example_inputs[x].clone() for x in sym_tensor_indices],
self.split_gm,
# this is the callable we return to Dynamo to run )
def copy_and_call(*args: Any) -> Any:
list_args = list(args)
for i, index in enumerate(self.sym_tensor_indices):
runtime_tensor = list_args[index]
runtime_shape = runtime_tensor.shape[0]
static_tensor = self.input_buffers[i][:runtime_shape]
# copy the tensor to the static buffer
static_tensor.copy_(runtime_tensor)
# replace the tensor in the list_args to the static buffer
list_args[index] = static_tensor
return self.split_gm(*list_args)
return VllmSerializableFunction( return VllmSerializableFunction(
graph, example_inputs, self.prefix, copy_and_call, self.is_encoder graph_to_serialize,
example_inputs,
self.prefix,
copy_and_call,
is_encoder=self.is_encoder,
vllm_backend=self,
sym_tensor_indices=sym_tensor_indices,
) )
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib
import inspect import inspect
import os import os
import pickle import pickle
...@@ -12,6 +13,7 @@ import torch ...@@ -12,6 +13,7 @@ import torch
from torch.utils import _pytree as pytree from torch.utils import _pytree as pytree
import vllm.envs as envs import vllm.envs as envs
from vllm.compilation.compiler_interface import get_inductor_factors
from vllm.config import VllmConfig, get_current_vllm_config from vllm.config import VllmConfig, get_current_vllm_config
from vllm.config.utils import hash_factors from vllm.config.utils import hash_factors
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -27,6 +29,121 @@ assert isinstance(SerializableCallable, type) ...@@ -27,6 +29,121 @@ assert isinstance(SerializableCallable, type)
logger = init_logger(__name__) logger = init_logger(__name__)
class StandaloneCompiledArtifacts:
"""Storage for standalone compiled artifacts with content-based deduplication.
Deduplication works via a two-level indirection:
1. `submodule_bytes` maps "{submod_name}_{shape}" -> SHA256 hash
2. `submodule_bytes_store` maps SHA256 hash -> actual bytes
When inserting, we compute the SHA256 hash of the bytes. If the hash
already exists in `submodule_bytes_store`, we reuse the existing entry
rather than storing duplicate bytes. This is common because submodules
often compile to identical artifacts (e.g., identical transformer layers
split on attn)
"""
def __init__(self):
# dict from submodule name to byte hash
self.submodule_bytes = {}
# dict from byte hash to bytes
self.submodule_bytes_store = {}
# dict from byte hash to loaded module
self.loaded_submodule_store = {}
def insert(self, submod_name: str, shape: str, entry: bytes):
hasher = hashlib.sha256()
hasher.update(entry)
hex_digest = hasher.hexdigest()
self.submodule_bytes[f"{submod_name}_{shape}"] = hex_digest
if hex_digest not in self.submodule_bytes_store:
self.submodule_bytes_store[hex_digest] = entry
logger.debug(
"inserting new artifact for submod %s with shape %s "
"(%s bytes) at hash %s",
submod_name,
shape,
len(entry),
hex_digest,
)
else:
logger.debug(
"reusing existing cache artifact for submod %s "
"with shape %s (%s bytes) at hash %s",
submod_name,
shape,
len(entry),
hex_digest,
)
def get(self, submod_name: str, shape: str) -> bytes:
logger.debug(
"getting artifact for submod %s with shape %s",
submod_name,
shape,
)
return self.submodule_bytes_store[
self.submodule_bytes[f"{submod_name}_{shape}"]
]
def get_loaded(self, submod_name: str, shape: str):
logger.debug(
"getting artifact for submod %s with shape %s",
submod_name,
shape,
)
return self.loaded_submodule_store[
self.submodule_bytes[f"{submod_name}_{shape}"]
]
def size_bytes(self) -> int:
return sum(len(entry) for entry in self.submodule_bytes_store.values())
def num_artifacts(self) -> int:
return len(self.submodule_bytes_store)
def num_entries(self) -> int:
return len(self.submodule_bytes)
def submodule_names(self) -> list[str]:
# get unique "{submod_name}" from "{submod_name}_{shape}", preserving order
names = [cache_key.rsplit("_", 1)[0] for cache_key in self.submodule_bytes]
return list(dict.fromkeys(names))
def load_all(self) -> None:
import concurrent.futures
# check already loaded
if len(self.loaded_submodule_store) == len(self.submodule_bytes_store):
return
from torch._inductor.standalone_compile import AOTCompiledArtifact
def _load_entry(entry_bytes) -> AOTCompiledArtifact:
entry = pickle.loads(entry_bytes)
return AOTCompiledArtifact.deserialize(entry)
with concurrent.futures.ThreadPoolExecutor() as executor:
entries = list(self.submodule_bytes_store.values())
loaded_entries = list(executor.map(_load_entry, entries))
for i, k in enumerate(self.submodule_bytes_store.keys()):
self.loaded_submodule_store[k] = loaded_entries[i]
logger.debug("loaded all %s submodules", self.num_artifacts())
def __getstate__(self):
return {
"submodule_bytes": self.submodule_bytes,
"submodule_bytes_store": self.submodule_bytes_store,
}
def __setstate__(self, state):
self.submodule_bytes = state["submodule_bytes"]
self.submodule_bytes_store = state["submodule_bytes_store"]
self.loaded_submodule_store = {}
class VllmSerializableFunction(SerializableCallable): # type: ignore[misc] class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
""" """
A wrapper around a compiled function by vllm. It will forward the tensor A wrapper around a compiled function by vllm. It will forward the tensor
...@@ -46,6 +163,8 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc] ...@@ -46,6 +163,8 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
prefix: str, prefix: str,
optimized_call: Callable[..., Any], optimized_call: Callable[..., Any],
is_encoder: bool = False, is_encoder: bool = False,
vllm_backend: Any | None = None,
sym_tensor_indices: list[int] | None = None,
) -> None: ) -> None:
assert isinstance(graph_module, torch.fx.GraphModule) assert isinstance(graph_module, torch.fx.GraphModule)
self.graph_module = graph_module self.graph_module = graph_module
...@@ -54,6 +173,8 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc] ...@@ -54,6 +173,8 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
self.optimized_call = optimized_call self.optimized_call = optimized_call
self.is_encoder = is_encoder self.is_encoder = is_encoder
self.shape_env = None self.shape_env = None
self.vllm_backend = vllm_backend
self.sym_tensor_indices = sym_tensor_indices
sym_input = next( sym_input = next(
(i for i in self.example_inputs if isinstance(i, torch.SymInt)), None (i for i in self.example_inputs if isinstance(i, torch.SymInt)), None
) )
...@@ -74,9 +195,15 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc] ...@@ -74,9 +195,15 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
state = compiled_fn.__dict__.copy() state = compiled_fn.__dict__.copy()
state.pop("optimized_call") state.pop("optimized_call")
state.pop("shape_env") state.pop("shape_env")
state.pop("vllm_backend", None)
for node in state["graph_module"].graph.nodes: for node in state["graph_module"].graph.nodes:
node.meta.pop("source_fn_stack", None) node.meta.pop("source_fn_stack", None)
node.meta.pop("nn_module_stack", None) node.meta.pop("nn_module_stack", None)
for name, submod in state["graph_module"].named_children():
if hasattr(submod, "graph"):
for node in submod.graph.nodes:
node.meta.pop("source_fn_stack", None)
node.meta.pop("nn_module_stack", None)
graph_reducer_override = GraphPickler.reducer_override graph_reducer_override = GraphPickler.reducer_override
...@@ -93,15 +220,36 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc] ...@@ -93,15 +220,36 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
return type(None), () return type(None), ()
return graph_reducer_override(self, obj) return graph_reducer_override(self, obj)
# Mask off tensor inputs since they are large and not needed. if state.get("sym_tensor_indices"):
# put tensor inputs on meta device since their data
# isn't needed, yet we need the meta for make_copy_and_call
state["example_inputs"] = pytree.tree_map_only( state["example_inputs"] = pytree.tree_map_only(
torch.Tensor, lambda _: None, state["example_inputs"] torch.Tensor,
lambda inp: torch.empty_like(inp, device="meta"),
state["example_inputs"],
)
else:
# mask off all tensor inputs since they are large and not needed.
state["example_inputs"] = pytree.tree_map_only(
torch.Tensor,
lambda inp: torch.empty_like(inp, device="meta"),
state["example_inputs"],
) )
with patch.object(GraphPickler, "reducer_override", _graph_reducer_override): with patch.object(GraphPickler, "reducer_override", _graph_reducer_override):
state["graph_module"] = GraphPickler.dumps( state["graph_module"] = GraphPickler.dumps(
state["graph_module"], Options(ops_filter=None) state["graph_module"], Options(ops_filter=None)
) )
state["example_inputs"] = GraphPickler.dumps(state["example_inputs"]) state["example_inputs"] = GraphPickler.dumps(state["example_inputs"])
if compiled_fn.vllm_backend:
(
standalone_compile_artifacts,
sym_shape_indices_map,
returns_tuple_map,
) = compiled_fn.vllm_backend.collect_standalone_compile_artifacts()
state["standalone_compile_artifacts"] = standalone_compile_artifacts
state["sym_shape_indices_map"] = sym_shape_indices_map
state["returns_tuple_map"] = returns_tuple_map
return pickle.dumps(state) return pickle.dumps(state)
@classmethod @classmethod
...@@ -111,15 +259,48 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc] ...@@ -111,15 +259,48 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
from torch.fx._graph_pickler import GraphPickler from torch.fx._graph_pickler import GraphPickler
from torch.fx.experimental.symbolic_shapes import ShapeEnv from torch.fx.experimental.symbolic_shapes import ShapeEnv
from vllm.compilation.backends import VllmBackend
state = pickle.loads(data) state = pickle.loads(data)
fake_mode = FakeTensorMode(shape_env=ShapeEnv()) fake_mode = FakeTensorMode(shape_env=ShapeEnv())
state["graph_module"] = GraphPickler.loads(state["graph_module"], fake_mode) state["graph_module"] = GraphPickler.loads(state["graph_module"], fake_mode)
state["graph_module"].recompile() state["graph_module"].recompile()
state["example_inputs"] = GraphPickler.loads(state["example_inputs"], fake_mode) state["example_inputs"] = GraphPickler.loads(state["example_inputs"], fake_mode)
standalone_compile_artifacts = state.pop("standalone_compile_artifacts", None)
sym_shape_indices_map = state.pop("sym_shape_indices_map", {})
returns_tuple_map = state.pop("returns_tuple_map", {})
if envs.VLLM_USE_MEGA_AOT_ARTIFACT:
assert standalone_compile_artifacts is not None
submod_names = standalone_compile_artifacts.submodule_names()
num_submods = len(submod_names)
num_artifacts = standalone_compile_artifacts.num_artifacts()
logger.info(
"reconstructing serializable fn from standalone compile "
"artifacts. num_artifacts=%d num_submods=%d",
num_artifacts,
num_submods,
)
fn = reconstruct_serializable_fn_from_mega_artifact(
state=state,
standalone_compile_artifacts=standalone_compile_artifacts,
vllm_config=get_current_vllm_config(),
sym_shape_indices_map=sym_shape_indices_map,
returns_tuple_map=returns_tuple_map,
)
logger.info(
"reconstructed serializable fn from standalone compile artifacts"
)
return fn
# Fall back to standard VllmBackend
from vllm.compilation.backends import VllmBackend
is_encoder = state.get("is_encoder", False) is_encoder = state.get("is_encoder", False)
vllm_backend = VllmBackend( vllm_backend: VllmBackend = VllmBackend(
get_current_vllm_config(), state["prefix"], is_encoder get_current_vllm_config(), state["prefix"], is_encoder
) )
...@@ -152,7 +333,140 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc] ...@@ -152,7 +333,140 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
return "VllmSerializableFunction" return "VllmSerializableFunction"
def compilation_config_hash_factors(vllm_config: VllmConfig) -> list[str]: def reconstruct_serializable_fn_from_mega_artifact(
state: dict[str, Any],
standalone_compile_artifacts: "StandaloneCompiledArtifacts",
vllm_config: VllmConfig,
sym_shape_indices_map: dict[str, list[int]],
returns_tuple_map: dict[str, bool],
) -> "VllmSerializableFunction":
"""Construct a VllmSerializableFunction from cached inductor artifacts.
This function reconstructs a callable model from pre-compiled inductor
artifacts without re-running the compilation. It:
1. Loads all cached artifacts
2. Builds compiled callables for each submodule/shape
3. Creates PiecewiseBackend instances that dispatch to cached artifacts
4. Wraps with cudagraph if needed
5. Returns the final VllmSerializableFunction
Note: This function shares similar logic with PiecewiseCompileInterpreter
in backends.py. Both create PiecewiseBackend instances and wrap them with
cudagraph. The key difference is:
- this function: PiecewiseBackend receives pre-compiled runnables
(compiled_runnables is set, graph is None)
- PiecewiseCompileInterpreter: PiecewiseBackend receives the FX graph
to compile (graph is set, compiled_runnables is None)
If modifying the backend creation/wrapping logic, consider updating both.
Args:
state: Deserialized state dict containing graph_module, example_inputs,
prefix, sym_tensor_indices, is_encoder, etc.
standalone_compile_artifacts: The StandaloneCompiledArtifacts containing
pre-compiled artifacts for each submodule/shape combination.
vllm_config: The vLLM configuration.
sym_shape_indices_map: Mapping from submod_name to sym_shape_indices.
returns_tuple_map: Mapping from submod_name to returns_tuple.
Returns:
A VllmSerializableFunction that can be called directly.
"""
from vllm.compilation.backends import (
VllmBackend,
make_copy_and_call,
wrap_with_cudagraph_if_needed,
)
from vllm.compilation.piecewise_backend import PiecewiseBackend
prefix = state["prefix"]
is_encoder = state.get("is_encoder", False)
split_gm = state["graph_module"]
compilation_config = vllm_config.compilation_config
standalone_compile_artifacts.load_all()
submod_names = standalone_compile_artifacts.submodule_names()
compiled_callables: dict[str, dict[str, Callable]] = {}
for cache_key in standalone_compile_artifacts.submodule_bytes:
submod_name, shape_str = cache_key.rsplit("_", 1)
compiled_callables.setdefault(submod_name, {})[shape_str] = (
standalone_compile_artifacts.get_loaded(submod_name, shape_str)
)
vllm_backend = VllmBackend(vllm_config, prefix, is_encoder)
dummy_cache_dir = os.path.join(envs.VLLM_CACHE_ROOT, "dummy_cache")
os.makedirs(dummy_cache_dir, exist_ok=True)
vllm_backend.compiler_manager.initialize_cache(
cache_dir=dummy_cache_dir,
disable_cache=True,
prefix=prefix,
)
# spot check that cached submodules exist in the graph structure
graph_children = {name for name, _ in split_gm.named_children()}
missing = set(submod_names) - graph_children
assert not missing, (
f"artifacts reference submodules not in graph: {missing}. "
f"graph has: {sorted(graph_children)}"
)
for i, submod_name in enumerate(submod_names):
assert submod_name in sym_shape_indices_map and submod_name in returns_tuple_map
sym_shape_indices = sym_shape_indices_map[submod_name]
returns_tuple = returns_tuple_map[submod_name]
runnables = compiled_callables[submod_name]
piecewise_backend = PiecewiseBackend(
graph=None, # not needed for cached artifacts
vllm_config=vllm_config,
piecewise_compile_index=i,
total_piecewise_compiles=len(submod_names),
sym_shape_indices=sym_shape_indices,
vllm_backend=vllm_backend,
returns_tuple=returns_tuple,
compiled_runnables=runnables,
)
is_first = i == 0
is_last = i == len(submod_names) - 1
wrapped_backend = wrap_with_cudagraph_if_needed(
piecewise_backend,
vllm_config,
compilation_config,
is_first,
is_last,
)
split_gm.__dict__[submod_name] = wrapped_backend
logger.debug(
"Replaced submodule %s with piecewise backend from cache",
submod_name,
)
if compilation_config.cudagraph_copy_inputs:
sym_tensor_indices = state["sym_tensor_indices"]
input_buffers = [
torch.empty_like(
state["example_inputs"][idx], device=vllm_config.device_config.device
)
for idx in sym_tensor_indices
]
optimized_call = make_copy_and_call(sym_tensor_indices, input_buffers, split_gm)
else:
optimized_call = split_gm
fn = VllmSerializableFunction(
**state,
optimized_call=optimized_call,
vllm_backend=None,
)
return fn
def aot_compile_hash_factors(vllm_config: VllmConfig) -> list[str]:
factors = [] factors = []
# 0. factors come from the env, for example, The values of # 0. factors come from the env, for example, The values of
# VLLM_PP_LAYER_PARTITION will affect the computation graph. # VLLM_PP_LAYER_PARTITION will affect the computation graph.
...@@ -163,6 +477,11 @@ def compilation_config_hash_factors(vllm_config: VllmConfig) -> list[str]: ...@@ -163,6 +477,11 @@ def compilation_config_hash_factors(vllm_config: VllmConfig) -> list[str]:
# model is created) # model is created)
config_hash = vllm_config.compute_hash() config_hash = vllm_config.compute_hash()
factors.append(config_hash) factors.append(config_hash)
# 2. inductor factors if applicable
if envs.VLLM_USE_MEGA_AOT_ARTIFACT:
factors.extend(get_inductor_factors())
return factors return factors
......
...@@ -16,9 +16,12 @@ import vllm.envs as envs ...@@ -16,9 +16,12 @@ import vllm.envs as envs
from vllm.compilation.counter import compilation_counter from vllm.compilation.counter import compilation_counter
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.utils import Range from vllm.config.utils import Range
from vllm.logger import init_logger
from vllm.utils.hashing import safe_hash from vllm.utils.hashing import safe_hash
from vllm.utils.torch_utils import is_torch_equal_or_newer from vllm.utils.torch_utils import is_torch_equal_or_newer
logger = init_logger(__name__)
class CompilerInterface: class CompilerInterface:
""" """
...@@ -230,12 +233,42 @@ class InductorStandaloneAdaptor(CompilerInterface): ...@@ -230,12 +233,42 @@ class InductorStandaloneAdaptor(CompilerInterface):
from torch._inductor import standalone_compile from torch._inductor import standalone_compile
compiled_graph = standalone_compile( supports_aot = is_torch_equal_or_newer("2.10.0.dev")
graph,
example_inputs, if not supports_aot and envs.VLLM_USE_MEGA_AOT_ARTIFACT:
dynamic_shapes=dynamic_shapes, logger.error(
options={"config_patches": current_config}, "CRITICAL: VLLM_USE_MEGA_AOT_ARTIFACT "
"is enabled but PyTorch version does not support 'aot' "
"parameter in standalone_compile. This requires PyTorch "
"2.10.0+. Falling back to non-AOT mode."
) )
compile_kwargs = {
"dynamic_shapes": dynamic_shapes,
"options": {
"config_patches": current_config,
},
}
use_aot: bool = supports_aot and envs.VLLM_USE_MEGA_AOT_ARTIFACT
# only add 'aot' parameter if both supported and enabled...
# this will set bundled_autograd_cache
# https://github.com/pytorch/pytorch/blob/9bbc5b2905c260adf41bc866a732f9c121a2828a/torch/_inductor/standalone_compile.py#L359 # noqa
if use_aot:
compile_kwargs["aot"] = True # type: ignore[assignment]
compiled_graph = standalone_compile(graph, example_inputs, **compile_kwargs)
if use_aot:
from torch._inductor.standalone_compile import AOTCompiledArtifact
assert isinstance(compiled_graph, AOTCompiledArtifact)
assert hasattr(compiled_graph, "serialize")
# just return the compiled graph and a key
# since we can serialize the bytes using to_bytes
# and reload it using the key when reading
return compiled_graph, None
# Save the compiled artifact to disk in the specified path # Save the compiled artifact to disk in the specified path
assert key is not None assert key is not None
path = os.path.join(self.cache_dir, key) path = os.path.join(self.cache_dir, key)
...@@ -619,6 +652,7 @@ def set_inductor_config(config: dict[str, Any], compile_range: Range) -> None: ...@@ -619,6 +652,7 @@ def set_inductor_config(config: dict[str, Any], compile_range: Range) -> None:
def set_functorch_config() -> None: def set_functorch_config() -> None:
if not envs.VLLM_USE_MEGA_AOT_ARTIFACT:
torch._functorch.config.bundled_autograd_cache = False torch._functorch.config.bundled_autograd_cache = False
......
...@@ -320,7 +320,7 @@ def _support_torch_compile( ...@@ -320,7 +320,7 @@ def _support_torch_compile(
return return
self._check_shape_invariants = shape_invariants self._check_shape_invariants = shape_invariants
self.was_aot_compile_fn_loaded_from_disk = False
compilation_counter.num_models_seen += 1 compilation_counter.num_models_seen += 1
self.compiled = False self.compiled = False
...@@ -417,9 +417,9 @@ def _support_torch_compile( ...@@ -417,9 +417,9 @@ def _support_torch_compile(
serialized backend artifacts), then we need to generate a new AOT serialized backend artifacts), then we need to generate a new AOT
compile artifact from scratch. compile artifact from scratch.
""" """
from .caching import compilation_config_hash_factors from .caching import aot_compile_hash_factors
factors: list[str] = compilation_config_hash_factors(self.vllm_config) factors: list[str] = aot_compile_hash_factors(self.vllm_config)
factors.append(_model_hash_key(self.forward)) factors.append(_model_hash_key(self.forward))
hash_key = hashlib.sha256(str(factors).encode()).hexdigest() hash_key = hashlib.sha256(str(factors).encode()).hexdigest()
...@@ -446,6 +446,7 @@ def _support_torch_compile( ...@@ -446,6 +446,7 @@ def _support_torch_compile(
if not self.compilation_config.dynamic_shapes_config.evaluate_guards: if not self.compilation_config.dynamic_shapes_config.evaluate_guards:
loaded_fn.disable_guard_check() loaded_fn.disable_guard_check()
self.aot_compiled_fn = loaded_fn self.aot_compiled_fn = loaded_fn
self.was_aot_compile_fn_loaded_from_disk = True
except Exception as e: except Exception as e:
if os.path.exists(aot_compilation_path): if os.path.exists(aot_compilation_path):
logger.warning( logger.warning(
...@@ -547,26 +548,45 @@ def _support_torch_compile( ...@@ -547,26 +548,45 @@ def _support_torch_compile(
logger.warning("Detected eager backend, disabling AOT compile.") logger.warning("Detected eager backend, disabling AOT compile.")
use_aot_compile = False use_aot_compile = False
if use_aot_compile: if use_aot_compile:
from vllm.compilation.backends import set_on_compilation_complete
# store the path for saving after warmup
self._aot_compilation_path = aot_compilation_path
self._aot_cache_dir = cache_dir
# set callback in context so it's available when compilation completes
with set_on_compilation_complete(self.save_aot_compiled_function):
self.aot_compiled_fn = self.aot_compile(*args, **kwargs) self.aot_compiled_fn = self.aot_compile(*args, **kwargs)
output = self.aot_compiled_fn(self, *args, **kwargs) output = self.aot_compiled_fn(self, *args, **kwargs)
assert aot_compilation_path is not None
assert cache_dir is not None
try:
os.makedirs(cache_dir, exist_ok=True)
self.aot_compiled_fn.save_compiled_function(aot_compilation_path)
except Exception as e:
logger.warning(
"Cannot save aot compilation to path %s, error: %s",
aot_compilation_path,
str(e),
)
else: else:
output = TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs) # type: ignore[arg-type] output = TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs) # type: ignore[arg-type]
self.compiled = True self.compiled = True
return output return output
# triggers VllmSerializableFunction.serialize()
def save_aot_compiled_function(self):
if self.was_aot_compile_fn_loaded_from_disk:
logger.debug("AOT compiled function was loaded from cache, skipping save")
return
assert (
self.aot_compiled_fn and self._aot_compilation_path and self._aot_cache_dir
)
logger.info("saving AOT compiled function to %s", self._aot_compilation_path)
try:
os.makedirs(self._aot_cache_dir, exist_ok=True)
self.aot_compiled_fn.save_compiled_function(self._aot_compilation_path)
logger.info("saved AOT compiled function to %s", self._aot_compilation_path)
except Exception as e:
logger.warning(
"unable to save AOT compiled function to %s: %s",
self._aot_compilation_path,
e,
)
cls.__call__ = __call__ cls.__call__ = __call__
cls.save_aot_compiled_function = save_aot_compiled_function
return cls return cls
......
...@@ -2,10 +2,15 @@ ...@@ -2,10 +2,15 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses import dataclasses
import io
import pickle
from collections.abc import Callable from collections.abc import Callable
from pickle import Pickler
from typing import Any from typing import Any
import torch._functorch.config
import torch.fx as fx import torch.fx as fx
from torch._inductor.runtime.triton_heuristics import CachingAutotuner
from vllm.compilation.backends import VllmBackend from vllm.compilation.backends import VllmBackend
from vllm.compilation.monitor import end_monitoring_torch_compile from vllm.compilation.monitor import end_monitoring_torch_compile
...@@ -26,12 +31,14 @@ class RangeEntry: ...@@ -26,12 +31,14 @@ class RangeEntry:
class PiecewiseBackend: class PiecewiseBackend:
def __init__( def __init__(
self, self,
graph: fx.GraphModule, graph: fx.GraphModule | None,
vllm_config: VllmConfig, vllm_config: VllmConfig,
piecewise_compile_index: int, piecewise_compile_index: int,
total_piecewise_compiles: int, total_piecewise_compiles: int,
sym_shape_indices: list[int], sym_shape_indices: list[int],
vllm_backend: VllmBackend, vllm_backend: VllmBackend,
returns_tuple: bool,
compiled_runnables: dict[str, Callable] | None = None,
): ):
""" """
The backend for piecewise compilation. The backend for piecewise compilation.
...@@ -41,13 +48,28 @@ class PiecewiseBackend: ...@@ -41,13 +48,28 @@ class PiecewiseBackend:
We will compile `self.graph` once for the general shape, We will compile `self.graph` once for the general shape,
and then compile for different shapes specified in and then compile for different shapes specified in
`compilation_config.compile_sizes`. `compilation_config.compile_sizes`.
This class supports two mutually exclusive modes:
1. Compilation (graph is set, compiled_runnables is None):
Used during initial compilation when we have the FX graph
and need to compile it for each shape range.
2. Precompilation (graph is None, compiled_runnables is set):
Used when loading from cache/AOT artifacts where we already
have pre-compiled callables and don't need the original graph.
Exactly one of graph or compiled_runnables must be provided.
""" """
assert bool(graph is not None) ^ bool(compiled_runnables is not None), (
"exactly one of graph and compiled_runnables should be set."
)
self.graph = graph self.graph = graph
self.vllm_config = vllm_config self.vllm_config = vllm_config
self.compilation_config = vllm_config.compilation_config self.compilation_config = vllm_config.compilation_config
self.piecewise_compile_index = piecewise_compile_index self.piecewise_compile_index = piecewise_compile_index
self.total_piecewise_compiles = total_piecewise_compiles self.total_piecewise_compiles = total_piecewise_compiles
self.vllm_backend = vllm_backend self.vllm_backend = vllm_backend
self.compiled_runnables = compiled_runnables
self.is_first_graph = piecewise_compile_index == 0 self.is_first_graph = piecewise_compile_index == 0
self.is_last_graph = piecewise_compile_index == total_piecewise_compiles - 1 self.is_last_graph = piecewise_compile_index == total_piecewise_compiles - 1
...@@ -77,6 +99,7 @@ class PiecewiseBackend: ...@@ -77,6 +99,7 @@ class PiecewiseBackend:
logger.debug_once(log_string) logger.debug_once(log_string)
self.sym_shape_indices = sym_shape_indices self.sym_shape_indices = sym_shape_indices
self.returns_tuple = returns_tuple
# the entries for ranges that we need to either # the entries for ranges that we need to either
self.range_entries: dict[Range, RangeEntry] = {} self.range_entries: dict[Range, RangeEntry] = {}
...@@ -108,12 +131,71 @@ class PiecewiseBackend: ...@@ -108,12 +131,71 @@ class PiecewiseBackend:
compile_range=range, compile_range=range,
) )
# get the on_compilation_complete callback from context...
# PiecewiseBackend is created during the first call,
# which is when the context is set (see compilation/decorators.py)
from vllm.compilation.backends import _on_compilation_complete_callback
self.on_compilation_complete = _on_compilation_complete_callback.get()
def get_compiled_graph_wrapper(self, compiled_graph):
def compiled_graph_wrapper(*args):
graph_output = compiled_graph(*args)
# unpack the tuple if needed
# TODO(rzou): the implication is that we're not
# reading the python bytecode correctly in vLLM?
if self.returns_tuple or not isinstance(graph_output, (tuple, list)):
return graph_output
else:
return graph_output[0]
return compiled_graph_wrapper
def check_for_ending_compilation(self) -> None: def check_for_ending_compilation(self) -> None:
if self.is_last_graph and not self.to_be_compiled_ranges: if self.is_last_graph and not self.to_be_compiled_ranges:
# no specific sizes to compile # no specific sizes to compile
# save the hash of the inductor graph for the next run # save the hash of the inductor graph for the next run
self.vllm_backend.compiler_manager.save_to_file() self.vllm_backend.compiler_manager.save_to_file()
end_monitoring_torch_compile(self.vllm_config) end_monitoring_torch_compile(self.vllm_config)
# Call the completion callback (e.g., to save AOT compiled function)
if self.on_compilation_complete is not None:
self.on_compilation_complete()
def to_bytes(self) -> dict[str, bytes]:
class StandaloneCompiledArtifactsPickler(Pickler):
def reducer_override(self, obj):
if isinstance(obj, CachingAutotuner):
obj.prepare_for_pickle()
return pickle.loads, (
pickle.dumps(
obj,
),
)
return NotImplemented
def serialize(fn) -> bytes:
assert hasattr(fn, "serialize"), "fn must have serialize method"
with torch._functorch.config.patch("bundled_autograd_cache", True):
entry = fn.serialize()
f = io.BytesIO()
StandaloneCompiledArtifactsPickler(f).dump(entry)
result = f.getvalue()
return result
out = {}
for range_key, entry in self.range_entries.items():
if not entry.compiled:
logger.debug(
"entry with range %s not compiled, so cannot get its bytes",
range_key,
)
continue
if hasattr(entry.runnable, "serialize"):
out[str(range_key)] = serialize(entry.runnable)
return out
def _fakify_args(self, args: tuple[Any, ...]) -> list[Any]: def _fakify_args(self, args: tuple[Any, ...]) -> list[Any]:
# We need to pass fake example_inputs, otherwise torch.compile # We need to pass fake example_inputs, otherwise torch.compile
...@@ -127,6 +209,7 @@ class PiecewiseBackend: ...@@ -127,6 +209,7 @@ class PiecewiseBackend:
# non fake tensors as example inputs! # non fake tensors as example inputs!
# See issue https://github.com/vllm-project/vllm/issues/27899 # See issue https://github.com/vllm-project/vllm/issues/27899
fake_example_inputs = [] fake_example_inputs = []
assert self.graph is not None
for node in self.graph.graph.nodes: for node in self.graph.graph.nodes:
# All place holders come first # All place holders come first
if node.op == "placeholder": if node.op == "placeholder":
...@@ -140,9 +223,11 @@ class PiecewiseBackend: ...@@ -140,9 +223,11 @@ class PiecewiseBackend:
self, range_entry: RangeEntry, args: tuple[Any, ...] self, range_entry: RangeEntry, args: tuple[Any, ...]
) -> Any: ) -> Any:
if not range_entry.compiled: if not range_entry.compiled:
range_entry.compiled = True if self.compiled_runnables is not None:
self.to_be_compiled_ranges.remove(range_entry.compile_range) range_entry.runnable = self.get_compiled_graph_wrapper(
self.compiled_runnables[str(range_entry.compile_range)]
)
else:
# args are real arguments # args are real arguments
# fakify for range, real args for concrete size. # fakify for range, real args for concrete size.
# For concrete size, we clear the shape env in # For concrete size, we clear the shape env in
...@@ -152,6 +237,10 @@ class PiecewiseBackend: ...@@ -152,6 +237,10 @@ class PiecewiseBackend:
if not range_entry.compile_range.is_single_size() if not range_entry.compile_range.is_single_size()
else list(args) else list(args)
) )
with (
torch._functorch.config.patch("bundled_autograd_cache", True),
):
range_entry.runnable = self.vllm_backend.compiler_manager.compile( range_entry.runnable = self.vllm_backend.compiler_manager.compile(
self.graph, self.graph,
args_list, args_list,
...@@ -162,6 +251,9 @@ class PiecewiseBackend: ...@@ -162,6 +251,9 @@ class PiecewiseBackend:
num_graphs=self.total_piecewise_compiles, num_graphs=self.total_piecewise_compiles,
) )
range_entry.compiled = True
self.to_be_compiled_ranges.remove(range_entry.compile_range)
self.check_for_ending_compilation() self.check_for_ending_compilation()
def _find_range_for_shape(self, runtime_shape: int) -> RangeEntry | None: def _find_range_for_shape(self, runtime_shape: int) -> RangeEntry | None:
......
...@@ -108,6 +108,7 @@ if TYPE_CHECKING: ...@@ -108,6 +108,7 @@ if TYPE_CHECKING:
VLLM_USE_AOT_COMPILE: bool = False VLLM_USE_AOT_COMPILE: bool = False
VLLM_USE_BYTECODE_HOOK: bool = False VLLM_USE_BYTECODE_HOOK: bool = False
VLLM_FORCE_AOT_LOAD: bool = False VLLM_FORCE_AOT_LOAD: bool = False
VLLM_USE_MEGA_AOT_ARTIFACT: bool = False
VLLM_USE_TRITON_AWQ: bool = False VLLM_USE_TRITON_AWQ: bool = False
VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
VLLM_SKIP_P2P_CHECK: bool = False VLLM_SKIP_P2P_CHECK: bool = False
...@@ -630,6 +631,13 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -630,6 +631,13 @@ environment_variables: dict[str, Callable[[], Any]] = {
# to load will result in a hard error when this is enabled. # to load will result in a hard error when this is enabled.
# Will be ignored when VLLM_USE_AOT_COMPILE is disabled. # Will be ignored when VLLM_USE_AOT_COMPILE is disabled.
"VLLM_FORCE_AOT_LOAD": lambda: os.environ.get("VLLM_FORCE_AOT_LOAD", "0") == "1", "VLLM_FORCE_AOT_LOAD": lambda: os.environ.get("VLLM_FORCE_AOT_LOAD", "0") == "1",
# Enable loading compiled models directly from cached standalone compile artifacts
# without re-splitting graph modules. This reduces overhead during model
# loading by using reconstruct_serializable_fn_from_mega_artifact.
"VLLM_USE_MEGA_AOT_ARTIFACT": lambda: os.environ.get(
"VLLM_USE_MEGA_AOT_ARTIFACT", "0"
)
== "1",
# local rank of the process in the distributed setting, used to determine # local rank of the process in the distributed setting, used to determine
# the GPU device id # the GPU device id
"LOCAL_RANK": lambda: int(os.environ.get("LOCAL_RANK", "0")), "LOCAL_RANK": lambda: int(os.environ.get("LOCAL_RANK", "0")),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment