Unverified Commit a6e72e1e authored by James Wu's avatar James Wu Committed by GitHub
Browse files

[Bugfix] [pytorch] Patch AOTAutogradCache._get_shape_env (#17142)


Signed-off-by: default avatarJames Wu <jjwu@meta.com>
parent 5e83a727
...@@ -195,7 +195,6 @@ class InductorAdaptor(CompilerInterface): ...@@ -195,7 +195,6 @@ class InductorAdaptor(CompilerInterface):
hash_str, file_path = None, None hash_str, file_path = None, None
from torch._inductor.codecache import (FxGraphCache, from torch._inductor.codecache import (FxGraphCache,
compiled_fx_graph_hash) compiled_fx_graph_hash)
if torch.__version__.startswith("2.5"): if torch.__version__.startswith("2.5"):
original_load = FxGraphCache.load original_load = FxGraphCache.load
original_load_name = "torch._inductor.codecache.FxGraphCache.load" original_load_name = "torch._inductor.codecache.FxGraphCache.load"
...@@ -280,6 +279,16 @@ class InductorAdaptor(CompilerInterface): ...@@ -280,6 +279,16 @@ class InductorAdaptor(CompilerInterface):
patch("torch._inductor.codecache.FxGraphCache._get_shape_env", patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
_get_shape_env)) _get_shape_env))
from torch._functorch._aot_autograd.autograd_cache import (
AOTAutogradCache)
# torch 2.8+ on main uses _get_shape_env in AOTAutogradCache
if hasattr(AOTAutogradCache, "_get_shape_env"):
stack.enter_context(
patch(
"torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env",
_get_shape_env))
# for forcing the graph to be cached # for forcing the graph to be cached
stack.enter_context( stack.enter_context(
patch( patch(
...@@ -325,11 +334,19 @@ class InductorAdaptor(CompilerInterface): ...@@ -325,11 +334,19 @@ class InductorAdaptor(CompilerInterface):
assert isinstance(handle[1], str) assert isinstance(handle[1], str)
hash_str = handle[0] hash_str = handle[0]
from torch._functorch._aot_autograd.autograd_cache import (
AOTAutogradCache)
from torch._inductor.codecache import FxGraphCache from torch._inductor.codecache import FxGraphCache
with ExitStack() as exit_stack: with ExitStack() as exit_stack:
exit_stack.enter_context( exit_stack.enter_context(
patch("torch._inductor.codecache.FxGraphCache._get_shape_env", patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
lambda *args, **kwargs: AlwaysHitShapeEnv())) lambda *args, **kwargs: AlwaysHitShapeEnv()))
# torch 2.8+ on main uses _get_shape_env in AOTAutogradCache
if hasattr(AOTAutogradCache, "_get_shape_env"):
exit_stack.enter_context(
patch(
"torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env",
lambda *args, **kwargs: AlwaysHitShapeEnv()))
# Dynamo metrics context, see method for more details. # Dynamo metrics context, see method for more details.
exit_stack.enter_context(self.metrics_context()) exit_stack.enter_context(self.metrics_context())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment