Unverified commit 762be26a, authored by Luka Govedič, committed by GitHub
Browse files

[Bugfix] Upgrade depyf to 0.19 and streamline custom pass logging (#20777)


Signed-off-by: Luka Govedic <lgovedic@redhat.com>
Signed-off-by: luka <lgovedic@redhat.com>
parent 6a9e6b2a
...@@ -40,7 +40,7 @@ six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that need ...@@ -40,7 +40,7 @@ six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that need
setuptools>=77.0.3,<80; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12 setuptools>=77.0.3,<80; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12
einops # Required for Qwen2-VL. einops # Required for Qwen2-VL.
compressed-tensors == 0.10.2 # required for compressed-tensors compressed-tensors == 0.10.2 # required for compressed-tensors
depyf==0.18.0 # required for profiling and debugging with compilation config depyf==0.19.0 # required for profiling and debugging with compilation config
cloudpickle # allows pickling lambda functions in model_executor/models/registry.py cloudpickle # allows pickling lambda functions in model_executor/models/registry.py
watchfiles # required for http server to monitor the updates of TLS files watchfiles # required for http server to monitor the updates of TLS files
python-json-logger # Used by logging as per examples/others/logging_configuration.md python-json-logger # Used by logging as per examples/others/logging_configuration.md
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
import tempfile
from typing import Any, Optional, Union from typing import Any, Optional, Union
import pytest import pytest
...@@ -111,6 +112,11 @@ def test_full_graph( ...@@ -111,6 +112,11 @@ def test_full_graph(
pass_config=PassConfig(enable_fusion=True, pass_config=PassConfig(enable_fusion=True,
enable_noop=True)), model) enable_noop=True)), model)
for model in models_list(keywords=["FP8-dynamic", "quantized.w8a8"]) for model in models_list(keywords=["FP8-dynamic", "quantized.w8a8"])
] + [
# Test depyf integration works
(CompilationConfig(level=CompilationLevel.PIECEWISE,
debug_dump_path=tempfile.gettempdir()),
("facebook/opt-125m", {})),
]) ])
# only test some of the models # only test some of the models
@create_new_process_for_each_test() @create_new_process_for_each_test()
......
...@@ -6,13 +6,7 @@ import time ...@@ -6,13 +6,7 @@ import time
import torch import torch
from torch._dynamo.utils import lazy_format_graph_code from torch._dynamo.utils import lazy_format_graph_code
from vllm.config import PassConfig, VllmConfig from vllm.config import VllmConfig
# yapf: disable
from vllm.distributed import get_tensor_model_parallel_rank as get_tp_rank
from vllm.distributed import (
get_tensor_model_parallel_world_size as get_tp_world_size)
from vllm.distributed import model_parallel_is_initialized as p_is_init
# yapf: enable
from vllm.logger import init_logger from vllm.logger import init_logger
from .inductor_pass import InductorPass from .inductor_pass import InductorPass
...@@ -34,22 +28,9 @@ class VllmInductorPass(InductorPass): ...@@ -34,22 +28,9 @@ class VllmInductorPass(InductorPass):
else None else None
self.pass_name = self.__class__.__name__ self.pass_name = self.__class__.__name__
def dump_graph(self, graph: torch.fx.Graph, stage: str, always=False): def dump_graph(self, graph: torch.fx.Graph, stage: str):
lazy_format_graph_code(stage, graph.owning_module) lazy_format_graph_code(stage, graph.owning_module)
if stage in self.pass_config.dump_graph_stages or always:
# Make sure filename includes rank in the distributed setting
parallel = p_is_init() and get_tp_world_size() > 1
rank = f"-{get_tp_rank()}" if parallel else ""
filepath = self.pass_config.dump_graph_dir / f"{stage}{rank}.py"
logger.info("%s printing graph to %s", self.pass_name, filepath)
with open(filepath, "w") as f:
src = graph.python_code(root_module="self", verbose=True).src
# Add imports so it's not full of errors
print("import torch; from torch import device", file=f)
print(src, file=f)
def begin(self): def begin(self):
self._start_time = time.perf_counter_ns() self._start_time = time.perf_counter_ns()
...@@ -61,10 +42,9 @@ class VllmInductorPass(InductorPass): ...@@ -61,10 +42,9 @@ class VllmInductorPass(InductorPass):
class PrinterInductorPass(VllmInductorPass): class PrinterInductorPass(VllmInductorPass):
def __init__(self, name: str, config: PassConfig, always=False): def __init__(self, name: str, config: VllmConfig):
super().__init__(config) super().__init__(config)
self.name = name self.name = name
self.always = always
def __call__(self, graph: torch.fx.Graph): def __call__(self, graph: torch.fx.Graph):
self.dump_graph(graph, self.name, always=self.always) self.dump_graph(graph, self.name)
...@@ -16,7 +16,6 @@ from dataclasses import (MISSING, Field, asdict, field, fields, is_dataclass, ...@@ -16,7 +16,6 @@ from dataclasses import (MISSING, Field, asdict, field, fields, is_dataclass,
replace) replace)
from functools import cached_property from functools import cached_property
from importlib.util import find_spec from importlib.util import find_spec
from pathlib import Path
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Literal, Optional, from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Literal, Optional,
Protocol, TypeVar, Union, cast, get_args) Protocol, TypeVar, Union, cast, get_args)
...@@ -3953,11 +3952,6 @@ class PassConfig: ...@@ -3953,11 +3952,6 @@ class PassConfig:
don't all have access to full configuration - that would create a cycle as don't all have access to full configuration - that would create a cycle as
the `PassManager` is set as a property of config.""" the `PassManager` is set as a property of config."""
dump_graph_stages: list[str] = field(default_factory=list)
"""List of stages for which we want to dump the graph. Each pass defines
its own stages (before, after, maybe in-between)."""
dump_graph_dir: Path = Path(".")
"""Directory to dump the graphs."""
enable_fusion: bool = field(default_factory=lambda: not envs.VLLM_USE_V1) enable_fusion: bool = field(default_factory=lambda: not envs.VLLM_USE_V1)
"""Whether to enable the custom fusion (RMSNorm/SiluMul+quant) pass.""" """Whether to enable the custom fusion (RMSNorm/SiluMul+quant) pass."""
enable_attn_fusion: bool = False enable_attn_fusion: bool = False
...@@ -3975,12 +3969,9 @@ class PassConfig: ...@@ -3975,12 +3969,9 @@ class PassConfig:
""" """
Produces a hash unique to the pass configuration. Produces a hash unique to the pass configuration.
Any new fields that affect compilation should be added to the hash. Any new fields that affect compilation should be added to the hash.
Do not include dump_graph_* in the hash - they don't affect Any future fields that don't affect compilation should be excluded.
compilation.
""" """
exclude = {"dump_graph_stages", "dump_graph_dir"} return InductorPass.hash_dict(asdict(self))
dict_ = {k: v for k, v in asdict(self).items() if k not in exclude}
return InductorPass.hash_dict(dict_)
def __post_init__(self) -> None: def __post_init__(self) -> None:
if not self.enable_noop: if not self.enable_noop:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment