Unverified Commit ac32e66c authored by Luka Govedič's avatar Luka Govedič Committed by GitHub
Browse files

[torch.compile] Reorganize vllm/compilation and tests/compile (0/N for vLLM IR) (#33731)


Signed-off-by: default avatarLuka Govedič <lgovedic@redhat.com>
Signed-off-by: default avatarProExpertProg <luka.govedic@gmail.com>
Signed-off-by: default avatarLuka Govedič <ProExpertProg@users.noreply.github.com>
parent f79d9dce
...@@ -10,8 +10,8 @@ from torch._higher_order_ops.auto_functionalize import auto_functionalized ...@@ -10,8 +10,8 @@ from torch._higher_order_ops.auto_functionalize import auto_functionalized
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from .fx_utils import is_func from ..fx_utils import is_func
from .vllm_inductor_pass import VllmInductorPass from ..vllm_inductor_pass import VllmInductorPass
logger = init_logger(__name__) logger = init_logger(__name__)
......
...@@ -9,8 +9,8 @@ from torch.fx.experimental.symbolic_shapes import statically_known_true ...@@ -9,8 +9,8 @@ from torch.fx.experimental.symbolic_shapes import statically_known_true
from vllm.logger import init_logger from vllm.logger import init_logger
from .fx_utils import is_func from ..fx_utils import is_func
from .vllm_inductor_pass import VllmInductorPass from ..vllm_inductor_pass import VllmInductorPass
logger = init_logger(__name__) logger = init_logger(__name__)
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from torch import fx from torch import fx
from vllm.compilation.vllm_inductor_pass import VllmInductorPass from ..vllm_inductor_pass import VllmInductorPass
class PostCleanupPass(VllmInductorPass): class PostCleanupPass(VllmInductorPass):
......
...@@ -11,7 +11,7 @@ from typing import TYPE_CHECKING, Any, ClassVar, Literal ...@@ -11,7 +11,7 @@ from typing import TYPE_CHECKING, Any, ClassVar, Literal
from pydantic import Field, TypeAdapter, field_validator from pydantic import Field, TypeAdapter, field_validator
import vllm.envs as envs import vllm.envs as envs
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass from vllm.compilation.passes.inductor_pass import CallableInductorPass, InductorPass
from vllm.config.utils import ( from vllm.config.utils import (
Range, Range,
config, config,
...@@ -170,7 +170,9 @@ class PassConfig: ...@@ -170,7 +170,9 @@ class PassConfig:
@staticmethod @staticmethod
def default_fi_allreduce_fusion_max_size_mb() -> dict[int, float]: def default_fi_allreduce_fusion_max_size_mb() -> dict[int, float]:
from vllm.compilation.collective_fusion import FI_ALLREDUCE_FUSION_MAX_SIZE_MB from vllm.compilation.passes.fusion.allreduce_rms_fusion import (
FI_ALLREDUCE_FUSION_MAX_SIZE_MB,
)
from vllm.platforms import current_platform from vllm.platforms import current_platform
if not current_platform.is_cuda(): if not current_platform.is_cuda():
......
...@@ -191,7 +191,7 @@ class Platform: ...@@ -191,7 +191,7 @@ class Platform:
Get the pass manager class for this platform. Get the pass manager class for this platform.
It will be registered as a custom pass under the current_platform.pass_key. It will be registered as a custom pass under the current_platform.pass_key.
""" """
return "vllm.compilation.pass_manager.PostGradPassManager" return "vllm.compilation.passes.pass_manager.PostGradPassManager"
@classmethod @classmethod
def get_compile_backend(cls) -> str: def get_compile_backend(cls) -> str:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment