Unverified Commit e30cedd4 authored by Richard Zou's avatar Richard Zou Committed by GitHub
Browse files

[torch.compile] Stop doing unnecessary FakeTensorProp in PiecewiseCompileInterpreter (#34093)


Signed-off-by: default avatarRichard Zou <zou3519@gmail.com>
parent 3bcd494e
...@@ -27,10 +27,29 @@ from ...utils import create_new_process_for_each_test ...@@ -27,10 +27,29 @@ from ...utils import create_new_process_for_each_test
from ..silly_attention import get_global_counter, reset_global_counter from ..silly_attention import get_global_counter, reset_global_counter
# Custom op that returns an unbacked symint during graph capture
@torch.library.custom_op("mylib::foo", mutates_args=())
def foo(x: torch.Tensor) -> int:
return 3
@foo.register_fake
def _(x):
return torch.library.get_ctx().new_dynamic_size()
@support_torch_compile @support_torch_compile
class SillyModel(nn.Module): class SillyModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None: def __init__(
self,
*,
vllm_config: VllmConfig,
prefix: str = "",
intermediate_unbacked=False,
**kwargs,
) -> None:
super().__init__() super().__init__()
self.intermediate_unbacked = intermediate_unbacked
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
""" """
...@@ -44,6 +63,13 @@ class SillyModel(nn.Module): ...@@ -44,6 +63,13 @@ class SillyModel(nn.Module):
torch.ops.silly.attention(x, x, x, out) torch.ops.silly.attention(x, x, x, out)
x = out x = out
x = x - 2 x = x - 2
if self.intermediate_unbacked:
# Test for unbacked symints: the following is a fancy way to multiply by 1
u0 = foo(x)
ones = x.new_ones(x.shape[0], u0).sum(-1) / 3
x = x * ones
x = x - 1 x = x - 1
out = torch.empty_like(x) out = torch.empty_like(x)
torch.ops.silly.attention(x, x, x, out) torch.ops.silly.attention(x, x, x, out)
...@@ -52,6 +78,7 @@ class SillyModel(nn.Module): ...@@ -52,6 +78,7 @@ class SillyModel(nn.Module):
return x return x
@torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True)
def _run_simple_model( def _run_simple_model(
splitting_ops, splitting_ops,
use_inductor_graph_partition, use_inductor_graph_partition,
...@@ -60,6 +87,8 @@ def _run_simple_model( ...@@ -60,6 +87,8 @@ def _run_simple_model(
expected_num_piecewise_capturable_graphs_seen, expected_num_piecewise_capturable_graphs_seen,
expected_num_backend_compilations, expected_num_backend_compilations,
expected_num_cudagraph_captured, expected_num_cudagraph_captured,
*,
intermediate_unbacked=False,
): ):
vllm_config = VllmConfig( vllm_config = VllmConfig(
compilation_config=CompilationConfig( compilation_config=CompilationConfig(
...@@ -72,7 +101,11 @@ def _run_simple_model( ...@@ -72,7 +101,11 @@ def _run_simple_model(
) )
) )
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
model = SillyModel(vllm_config=vllm_config, prefix="") model = SillyModel(
vllm_config=vllm_config,
prefix="",
intermediate_unbacked=intermediate_unbacked,
)
inputs = torch.randn(100).cuda() inputs = torch.randn(100).cuda()
...@@ -125,9 +158,10 @@ def _run_simple_model( ...@@ -125,9 +158,10 @@ def _run_simple_model(
@pytest.mark.parametrize("backend", ["inductor", "eager"]) @pytest.mark.parametrize("backend", ["inductor", "eager"])
@pytest.mark.parametrize("intermediate_unbacked", [True, False])
@torch.inference_mode() @torch.inference_mode()
@create_new_process_for_each_test("spawn") @create_new_process_for_each_test("spawn")
def test_simple_piecewise_compile(backend): def test_simple_piecewise_compile(backend, intermediate_unbacked):
_run_simple_model( _run_simple_model(
splitting_ops=["silly::attention"], splitting_ops=["silly::attention"],
use_inductor_graph_partition=False, use_inductor_graph_partition=False,
...@@ -140,6 +174,7 @@ def test_simple_piecewise_compile(backend): ...@@ -140,6 +174,7 @@ def test_simple_piecewise_compile(backend):
expected_num_backend_compilations=3, expected_num_backend_compilations=3,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
expected_num_cudagraph_captured=6, expected_num_cudagraph_captured=6,
intermediate_unbacked=intermediate_unbacked,
) )
......
...@@ -570,7 +570,9 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): # type: ignore[misc] ...@@ -570,7 +570,9 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): # type: ignore[misc]
) -> Any: ) -> Any:
assert isinstance(target, str) assert isinstance(target, str)
output = super().call_module(target, args, kwargs) gm = getattr(self.module, target)
outputs = gm.graph.output_node().args[0]
output = fx.map_arg(outputs, lambda node: node.meta["example_value"])
if target in self.compile_submod_names: if target in self.compile_submod_names:
index = self.compile_submod_names.index(target) index = self.compile_submod_names.index(target)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment