Unverified commit e5ff1402, authored by Jiangyun Zhu and committed by GitHub
Browse files

[cudagraph] fix cudagraph warning in deepseekv32 (#28044)


Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
parent 0a6a3a12
......@@ -184,3 +184,56 @@ def test_consecutive_ops_in_split():
assert [node.op for node in splitting_gm.graph.nodes] == ["placeholder"] + 2 * [
"call_function"
] + ["output"]
def test_empty_only_partition_is_merged():
    """
    Check that Dynamo FX splitting folds a partition consisting solely of an
    empty allocation into the partition that precedes it.
    """

    def model_fn(x: torch.Tensor) -> torch.Tensor:
        sin_x = torch.sin(x)
        buf = torch.empty_like(sin_x)
        torch.ops.aten.cos.out(sin_x, out=buf)
        return buf

    sample = torch.randn(4, 3)
    traced = make_fx(model_fn)(sample)

    # Split at both compute ops so that the only node left between them is
    # the aten::empty_like allocation.
    partitioned, pieces = split_graph(traced, ["aten::sin", "aten::cos.out"])

    # If the merge did not happen there would be 3 partitions, the middle one
    # holding nothing but aten::empty_like.
    assert len(pieces) == 2, "Empty-only partition should be merged"

    # The split module must compute the same result as the unsplit graph.
    expected = traced(sample)
    actual = partitioned(sample)
    assert torch.allclose(expected, actual), "Output mismatch after split"
def test_builtin_empty_only_partition_is_merged():
    """
    Dynamo graphs may reference torch.empty/empty_like as builtin call targets
    rather than aten OpOverloads; verify that partitions holding only such a
    call are merged as well.
    """

    def model_fn(x: torch.Tensor) -> torch.Tensor:
        first_out = torch.empty_like(x)
        torch.ops.silly.attention(x, x, x, first_out)
        second_out = torch.empty_like(x)
        torch.ops.silly.attention(first_out, first_out, first_out, second_out)
        return second_out

    traced = torch.fx.symbolic_trace(model_fn)
    partitioned, pieces = split_graph(traced, ["silly::attention"])

    # Left unmerged, the graph would yield 4 partitions:
    # [empty_like], [attention], [empty_like], [attention].
    assert len(pieces) == 3, "Builtin empty-only partition should be merged"

    # The split module must compute the same result as the unsplit graph.
    sample = torch.randn(2, 3, device="cuda")
    expected = traced(sample)
    actual = partitioned(sample)
    assert torch.allclose(expected, actual), "Output mismatch after split"
......@@ -9,6 +9,7 @@ import operator
import os
import pprint
import time
from collections import defaultdict
from collections.abc import Callable, Generator, Sequence
from contextlib import contextmanager
from copy import deepcopy
......@@ -405,6 +406,58 @@ class SplitItem:
graph: fx.GraphModule
def _is_empty_allocation_node(node: fx.Node) -> bool:
    """
    Return whether ``node`` is an "empty" tensor allocation.

    Three forms are recognized:

    - a ``call_method`` node whose target is ``"new_empty"``;
    - a ``call_function`` node that targets ``torch.empty``,
      ``torch.empty_like`` or ``torch.empty_strided`` directly;
    - a ``call_function`` node that targets an ``OpOverloadPacket`` or
      ``OpOverload`` whose qualified name starts with ``aten::empty`` or
      ``aten::new_empty``.

    Any other node yields ``False``.
    """
    op_kind = node.op
    if op_kind == "call_method":
        return node.target == "new_empty"
    if op_kind != "call_function":
        return False

    callee = node.target
    # Builtin callables, as opposed to aten operator objects.
    if callee in (torch.empty, torch.empty_like, torch.empty_strided):
        return True

    # For aten operators, decide on the qualified operator name.
    if isinstance(callee, torch._ops.OpOverloadPacket):
        qualified_name = callee._qualified_op_name
    elif isinstance(callee, torch._ops.OpOverload):
        qualified_name = callee.name()
    else:
        return False
    return qualified_name.startswith(("aten::empty", "aten::new_empty"))
def _merge_empty_only_subgraphs(
    node_to_subgraph_id: dict[fx.Node, int],
) -> None:
    """
    Fold every partition that holds exactly one node, that node being an
    empty allocation, into the closest earlier partition that was kept.

    ``node_to_subgraph_id`` is updated in place; nothing is returned. A
    single-allocation partition with no earlier partition is left untouched.
    The purpose is to avoid standalone empty submodules, which can lead to
    empty cudagraph captures.
    """
    # Group nodes per partition. A dict keeps insertion order, so iterating
    # it later visits partition ids in the order they were first seen.
    members: dict[int, list[fx.Node]] = {}
    for fx_node, part_id in node_to_subgraph_id.items():
        members.setdefault(part_id, []).append(fx_node)

    # Id of the most recent partition that was NOT merged away.
    anchor_id: int | None = None
    for part_id, part_nodes in members.items():
        lone_allocation = len(part_nodes) == 1 and _is_empty_allocation_node(
            part_nodes[0]
        )
        if lone_allocation and anchor_id is not None:
            # Reassign the allocation; the anchor is deliberately unchanged so
            # consecutive allocation-only partitions land in the same target.
            node_to_subgraph_id[part_nodes[0]] = anchor_id
        else:
            anchor_id = part_id
def split_graph(
graph: fx.GraphModule, splitting_ops: list[str]
) -> tuple[fx.GraphModule, list[SplitItem]]:
......@@ -443,6 +496,8 @@ def split_graph(
else:
node_to_subgraph_id[node] = subgraph_id
_merge_empty_only_subgraphs(node_to_subgraph_id)
# `keep_original_order` is important!
# otherwise pytorch might reorder the nodes and
# the semantics of the graph will change when we
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment