Unverified Commit a6e72e1e authored by James Wu's avatar James Wu Committed by GitHub
Browse files

[Bugfix] [pytorch] Patch AOTAutogradCache._get_shape_env (#17142)


Signed-off-by: default avatarJames Wu <jjwu@meta.com>
parent 5e83a727
......@@ -195,7 +195,6 @@ class InductorAdaptor(CompilerInterface):
hash_str, file_path = None, None
from torch._inductor.codecache import (FxGraphCache,
compiled_fx_graph_hash)
if torch.__version__.startswith("2.5"):
original_load = FxGraphCache.load
original_load_name = "torch._inductor.codecache.FxGraphCache.load"
......@@ -280,6 +279,16 @@ class InductorAdaptor(CompilerInterface):
patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
_get_shape_env))
from torch._functorch._aot_autograd.autograd_cache import (
AOTAutogradCache)
# torch 2.8+ on main uses _get_shape_env in AOTAutogradCache
if hasattr(AOTAutogradCache, "_get_shape_env"):
stack.enter_context(
patch(
"torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env",
_get_shape_env))
# for forcing the graph to be cached
stack.enter_context(
patch(
......@@ -325,11 +334,19 @@ class InductorAdaptor(CompilerInterface):
assert isinstance(handle[1], str)
hash_str = handle[0]
from torch._functorch._aot_autograd.autograd_cache import (
AOTAutogradCache)
from torch._inductor.codecache import FxGraphCache
with ExitStack() as exit_stack:
exit_stack.enter_context(
patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
lambda *args, **kwargs: AlwaysHitShapeEnv()))
# torch 2.8+ on main uses _get_shape_env in AOTAutogradCache
if hasattr(AOTAutogradCache, "_get_shape_env"):
exit_stack.enter_context(
patch(
"torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env",
lambda *args, **kwargs: AlwaysHitShapeEnv()))
# Dynamo metrics context, see method for more details.
exit_stack.enter_context(self.metrics_context())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment