Unverified commit 11ae016b, authored by Lucia Fang, committed by GitHub
Browse files

[torch.compile] Passing only necessary compilation config to inductor pass config (#27041)


Signed-off-by: Lu Fang <fanglu@fb.com>
Co-authored-by: Lucia (Lu) Fang <fanglu@meta.com>
parent 41d30719
...@@ -341,6 +341,15 @@ def async_tp_pass_on_test_model( ...@@ -341,6 +341,15 @@ def async_tp_pass_on_test_model(
async_tp_pass = AsyncTPPass(vllm_config) async_tp_pass = AsyncTPPass(vllm_config)
backend = TestBackend(async_tp_pass) backend = TestBackend(async_tp_pass)
assert (
async_tp_pass.compilation_config.splitting_ops
== vllm_config.compilation_config.splitting_ops
)
assert (
async_tp_pass.compilation_config.use_inductor_graph_partition
== vllm_config.compilation_config.use_inductor_graph_partition
)
model = test_model_cls(hidden_size, dtype) # Pass dtype to model constructor model = test_model_cls(hidden_size, dtype) # Pass dtype to model constructor
hidden_states = torch.randn( hidden_states = torch.randn(
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
import pytest import pytest
from vllm.compilation.counter import compilation_counter from vllm.compilation.counter import compilation_counter
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
from vllm.config.compilation import CompilationMode from vllm.config.compilation import CompilationMode
from vllm.utils import _is_torch_equal_or_newer, is_torch_equal_or_newer from vllm.utils import _is_torch_equal_or_newer, is_torch_equal_or_newer
...@@ -25,6 +28,20 @@ def test_use_cudagraphs_dynamic(): ...@@ -25,6 +28,20 @@ def test_use_cudagraphs_dynamic():
assert vllm_config.compilation_config.use_cudagraph assert vllm_config.compilation_config.use_cudagraph
def test_copy_pass():
    """Deep-copy an inductor pass and check its compilation settings survive.

    Builds a ``FixFunctionalizationPass`` from a default ``VllmConfig``,
    deep-copies it, and verifies that the copy's ``compilation_config``
    reports the same ``use_inductor_graph_partition`` and ``splitting_ops``
    values as the config the pass was built from.
    """
    vllm_config = VllmConfig()
    original_pass = FixFunctionalizationPass(vllm_config)

    # The deepcopy call itself is part of what is under test: it must
    # complete without raising.
    cloned_pass = copy.deepcopy(original_pass)

    expected = vllm_config.compilation_config
    actual = cloned_pass.compilation_config

    assert (
        actual.use_inductor_graph_partition
        == expected.use_inductor_graph_partition
    )
    assert actual.splitting_ops == expected.splitting_ops
def test_custom_op(): def test_custom_op():
# proper syntax # proper syntax
_ = CompilationConfig(custom_ops=["+quant_fp8", "-silu_and_mul"]) _ = CompilationConfig(custom_ops=["+quant_fp8", "-silu_and_mul"])
......
...@@ -285,6 +285,14 @@ def sequence_parallelism_pass_on_test_model( ...@@ -285,6 +285,14 @@ def sequence_parallelism_pass_on_test_model(
noop_pass = NoOpEliminationPass(vllm_config) noop_pass = NoOpEliminationPass(vllm_config)
sequence_parallelism_pass = SequenceParallelismPass(vllm_config) sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
assert (
sequence_parallelism_pass.compilation_config.splitting_ops
== vllm_config.compilation_config.splitting_ops
)
assert (
sequence_parallelism_pass.compilation_config.use_inductor_graph_partition
== vllm_config.compilation_config.use_inductor_graph_partition
)
func_pass = FixFunctionalizationPass(vllm_config) func_pass = FixFunctionalizationPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config) cleanup_pass = PostCleanupPass(vllm_config)
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
import functools import functools
import operator import operator
import time import time
import weakref from dataclasses import dataclass
from typing import ClassVar from typing import ClassVar
import regex as re import regex as re
...@@ -19,6 +19,12 @@ from .inductor_pass import InductorPass ...@@ -19,6 +19,12 @@ from .inductor_pass import InductorPass
logger = init_logger(__name__) logger = init_logger(__name__)
@dataclass
class InductorCompilationConfig:
splitting_ops: list[str] | None = None
use_inductor_graph_partition: bool = False
class VllmInductorPass(InductorPass): class VllmInductorPass(InductorPass):
""" """
An inductor pass with access to vLLM PassConfig. An inductor pass with access to vLLM PassConfig.
...@@ -29,7 +35,12 @@ class VllmInductorPass(InductorPass): ...@@ -29,7 +35,12 @@ class VllmInductorPass(InductorPass):
"""Keep track of pass index for debug dump ordering.""" """Keep track of pass index for debug dump ordering."""
def __init__(self, config: VllmConfig): def __init__(self, config: VllmConfig):
self.compilation_config = weakref.proxy(config.compilation_config) # Get only the necessary CompilationConfig for the inductor pass, since
# the full `CompilationConfig` contains a pointer to the model, which is unsafe.
self.compilation_config = InductorCompilationConfig(
splitting_ops=config.compilation_config.splitting_ops,
use_inductor_graph_partition=config.compilation_config.use_inductor_graph_partition,
)
self.pass_config = config.compilation_config.pass_config self.pass_config = config.compilation_config.pass_config
self.model_dtype = config.model_config.dtype if config.model_config else None self.model_dtype = config.model_config.dtype if config.model_config else None
self.device = config.device_config.device if config.device_config else None self.device = config.device_config.device if config.device_config else None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment