Unverified Commit e5ff1402 authored by Jiangyun Zhu's avatar Jiangyun Zhu Committed by GitHub
Browse files

[cudagraph] fix cudagraph warning in deepseekv32 (#28044)


Signed-off-by: default avatarzjy0516 <riverclouds.zhu@qq.com>
parent 0a6a3a12
...@@ -184,3 +184,56 @@ def test_consecutive_ops_in_split(): ...@@ -184,3 +184,56 @@ def test_consecutive_ops_in_split():
assert [node.op for node in splitting_gm.graph.nodes] == ["placeholder"] + 2 * [ assert [node.op for node in splitting_gm.graph.nodes] == ["placeholder"] + 2 * [
"call_function" "call_function"
] + ["output"] ] + ["output"]
def test_empty_only_partition_is_merged():
"""
Test that an empty-allocation-only partition is merged into its previous
partition during Dynamo FX splitting.
"""
def model_fn(x: torch.Tensor) -> torch.Tensor:
y = torch.sin(x)
out = torch.empty_like(y)
torch.ops.aten.cos.out(y, out=out)
return out
x = torch.randn(4, 3)
gm = make_fx(model_fn)(x)
split_ops = ["aten::sin", "aten::cos.out"]
split_gm, split_items = split_graph(gm, split_ops)
# Without the merge, this graph is split into 3 partitions where the
# middle partition contains only aten::empty_like.
assert len(split_items) == 2, "Empty-only partition should be merged"
output_original = gm(x)
output_split = split_gm(x)
assert torch.allclose(output_original, output_split), "Output mismatch after split"
def test_builtin_empty_only_partition_is_merged():
"""
In Dynamo graphs, torch.empty/empty_like may appear as builtin call targets
(not aten OpOverload). Ensure empty-only partitions are still merged.
"""
def model_fn(x: torch.Tensor) -> torch.Tensor:
out1 = torch.empty_like(x)
torch.ops.silly.attention(x, x, x, out1)
out2 = torch.empty_like(x)
torch.ops.silly.attention(out1, out1, out1, out2)
return out2
gm = torch.fx.symbolic_trace(model_fn)
split_gm, split_items = split_graph(gm, ["silly::attention"])
# Without the empty-only merge, this graph creates 4 partitions:
# [empty_like], [attention], [empty_like], [attention].
assert len(split_items) == 3, "Builtin empty-only partition should be merged"
x = torch.randn(2, 3, device="cuda")
output_original = gm(x)
output_split = split_gm(x)
assert torch.allclose(output_original, output_split), "Output mismatch after split"
...@@ -9,6 +9,7 @@ import operator ...@@ -9,6 +9,7 @@ import operator
import os import os
import pprint import pprint
import time import time
from collections import defaultdict
from collections.abc import Callable, Generator, Sequence from collections.abc import Callable, Generator, Sequence
from contextlib import contextmanager from contextlib import contextmanager
from copy import deepcopy from copy import deepcopy
...@@ -405,6 +406,58 @@ class SplitItem: ...@@ -405,6 +406,58 @@ class SplitItem:
graph: fx.GraphModule graph: fx.GraphModule
def _is_empty_allocation_node(node: fx.Node) -> bool:
if node.op == "call_method":
return node.target == "new_empty"
if node.op != "call_function":
return False
target = node.target
if target in (torch.empty, torch.empty_like, torch.empty_strided):
return True
if isinstance(target, torch._ops.OpOverloadPacket):
packet_name = target._qualified_op_name
elif isinstance(target, torch._ops.OpOverload):
packet_name = target.name()
else:
return False
return packet_name.startswith("aten::empty") or packet_name.startswith(
"aten::new_empty"
)
def _merge_empty_only_subgraphs(
node_to_subgraph_id: dict[fx.Node, int],
) -> None:
"""
Merge a partition that only contains an empty allocation op into the
previous partition. This avoids generating standalone empty submodules,
which can lead to empty cudagraph captures.
"""
nodes_by_subgraph_id: dict[int, list[fx.Node]] = defaultdict(list)
subgraph_id_order: list[int] = []
for node, subgraph_id in node_to_subgraph_id.items():
if subgraph_id not in nodes_by_subgraph_id:
subgraph_id_order.append(subgraph_id)
nodes_by_subgraph_id[subgraph_id].append(node)
prev_subgraph_id: int | None = None
for subgraph_id in subgraph_id_order:
nodes = nodes_by_subgraph_id[subgraph_id]
if (
len(nodes) == 1
and _is_empty_allocation_node(nodes[0])
and prev_subgraph_id is not None
):
node_to_subgraph_id[nodes[0]] = prev_subgraph_id
continue
prev_subgraph_id = subgraph_id
def split_graph( def split_graph(
graph: fx.GraphModule, splitting_ops: list[str] graph: fx.GraphModule, splitting_ops: list[str]
) -> tuple[fx.GraphModule, list[SplitItem]]: ) -> tuple[fx.GraphModule, list[SplitItem]]:
...@@ -443,6 +496,8 @@ def split_graph( ...@@ -443,6 +496,8 @@ def split_graph(
else: else:
node_to_subgraph_id[node] = subgraph_id node_to_subgraph_id[node] = subgraph_id
_merge_empty_only_subgraphs(node_to_subgraph_id)
# `keep_original_order` is important! # `keep_original_order` is important!
# otherwise pytorch might reorder the nodes and # otherwise pytorch might reorder the nodes and
# the semantics of the graph will change when we # the semantics of the graph will change when we
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment