Unverified Commit 07ea184f authored by Angela Yi's avatar Angela Yi Committed by GitHub
Browse files

[ez] Delete more torch version checks <= 2.8 (#33288)


Signed-off-by: default avatarangelayi <yiangela7@gmail.com>
parent a663b218
...@@ -375,14 +375,13 @@ class InductorAdaptor(CompilerInterface): ...@@ -375,14 +375,13 @@ class InductorAdaptor(CompilerInterface):
# it to get the hash of the compiled graph directly. # it to get the hash of the compiled graph directly.
hash_str, file_path = None, None hash_str, file_path = None, None
from torch._inductor.codecache import FxGraphCache, compiled_fx_graph_hash from torch._inductor.codecache import compiled_fx_graph_hash
if torch.__version__.startswith("2.5"): def hijacked_compile_fx_inner(*args: Any, **kwargs: Any) -> Any:
original_load = FxGraphCache.load output = torch._inductor.compile_fx.compile_fx_inner(*args, **kwargs)
original_load_name = "torch._inductor.codecache.FxGraphCache.load" nonlocal hash_str
inductor_compiled_graph = output
def hijack_load(*args: Any, **kwargs: Any) -> Any: if inductor_compiled_graph is not None:
inductor_compiled_graph = original_load(*args, **kwargs)
nonlocal file_path nonlocal file_path
compiled_fn = inductor_compiled_graph.current_callable compiled_fn = inductor_compiled_graph.current_callable
file_path = compiled_fn.__code__.co_filename # noqa file_path = compiled_fn.__code__.co_filename # noqa
...@@ -395,44 +394,14 @@ class InductorAdaptor(CompilerInterface): ...@@ -395,44 +394,14 @@ class InductorAdaptor(CompilerInterface):
for cell in compiled_fn.__closure__: for cell in compiled_fn.__closure__:
if not callable(cell.cell_contents): if not callable(cell.cell_contents):
continue continue
if cell.cell_contents.__code__.co_filename.startswith( code = cell.cell_contents.__code__
self.base_cache_dir if code.co_filename.startswith(self.base_cache_dir):
): # this is the real file path
# this is the real file path compiled from Inductor # compiled from Inductor
file_path = cell.cell_contents.__code__.co_filename file_path = code.co_filename
break break
return inductor_compiled_graph hash_str = inductor_compiled_graph._fx_graph_cache_key
return output
hijacked_compile_fx_inner = torch._inductor.compile_fx.compile_fx_inner # noqa
elif torch.__version__ >= "2.6":
# function renamed in 2.6
original_load_name = None
def hijacked_compile_fx_inner(*args: Any, **kwargs: Any) -> Any:
output = torch._inductor.compile_fx.compile_fx_inner(*args, **kwargs)
nonlocal hash_str
inductor_compiled_graph = output
if inductor_compiled_graph is not None:
nonlocal file_path
compiled_fn = inductor_compiled_graph.current_callable
file_path = compiled_fn.__code__.co_filename # noqa
if (
not file_path.startswith(self.base_cache_dir)
and compiled_fn.__closure__ is not None
):
# hooked in the align_inputs_from_check_idxs function
# in torch/_inductor/utils.py
for cell in compiled_fn.__closure__:
if not callable(cell.cell_contents):
continue
code = cell.cell_contents.__code__
if code.co_filename.startswith(self.base_cache_dir):
# this is the real file path
# compiled from Inductor
file_path = code.co_filename
break
hash_str = inductor_compiled_graph._fx_graph_cache_key
return output
def hijack_compiled_fx_graph_hash(*args: Any, **kwargs: Any) -> Any: def hijack_compiled_fx_graph_hash(*args: Any, **kwargs: Any) -> Any:
out = compiled_fx_graph_hash(*args, **kwargs) out = compiled_fx_graph_hash(*args, **kwargs)
...@@ -453,10 +422,6 @@ class InductorAdaptor(CompilerInterface): ...@@ -453,10 +422,6 @@ class InductorAdaptor(CompilerInterface):
return AlwaysHitShapeEnv() return AlwaysHitShapeEnv()
with ExitStack() as stack: with ExitStack() as stack:
# hijack to get the compiled graph itself
if original_load_name is not None:
stack.enter_context(patch(original_load_name, hijack_load))
# for hijacking the hash of the compiled graph # for hijacking the hash of the compiled graph
stack.enter_context( stack.enter_context(
patch( patch(
...@@ -573,25 +538,16 @@ class InductorAdaptor(CompilerInterface): ...@@ -573,25 +538,16 @@ class InductorAdaptor(CompilerInterface):
# Dynamo metrics context, see method for more details. # Dynamo metrics context, see method for more details.
exit_stack.enter_context(self.metrics_context()) exit_stack.enter_context(self.metrics_context())
if torch.__version__.startswith("2.5"): from torch._inductor.output_code import CompiledFxGraphConstantsWithGm
inductor_compiled_graph = FxGraphCache._lookup_graph(
hash_str, example_inputs, True, False
)
assert inductor_compiled_graph is not None, (
"Inductor cache lookup failed. Please remove "
f"the cache directory and try again." # noqa
)
elif torch.__version__ >= "2.6":
from torch._inductor.output_code import CompiledFxGraphConstantsWithGm
constants = CompiledFxGraphConstantsWithGm(graph) constants = CompiledFxGraphConstantsWithGm(graph)
inductor_compiled_graph, _ = FxGraphCache._lookup_graph( inductor_compiled_graph, _ = FxGraphCache._lookup_graph(
hash_str, example_inputs, True, None, constants hash_str, example_inputs, True, None, constants
) )
assert inductor_compiled_graph is not None, ( assert inductor_compiled_graph is not None, (
"Inductor cache lookup failed. Please remove " "Inductor cache lookup failed. Please remove "
f"the cache directory and try again." # noqa f"the cache directory and try again." # noqa
) )
# Inductor calling convention (function signature): # Inductor calling convention (function signature):
# f(list) -> tuple # f(list) -> tuple
......
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
import torch import torch
from packaging import version
from vllm.config import CompilationMode, get_current_vllm_config from vllm.config import CompilationMode, get_current_vllm_config
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -98,9 +97,6 @@ class RowWiseTorchFP8ScaledMMLinearKernel(TorchFP8ScaledMMLinearKernel): ...@@ -98,9 +97,6 @@ class RowWiseTorchFP8ScaledMMLinearKernel(TorchFP8ScaledMMLinearKernel):
if compute_capability is not None and compute_capability < 94: if compute_capability is not None and compute_capability < 94:
return False, "requires compute capability 94 and above." return False, "requires compute capability 94 and above."
if not version.parse(torch.__version__) >= version.parse("2.7"):
return False, "requires pytorch version >=2.7."
return True, None return True, None
@classmethod @classmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment