# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import contextlib

6
import torch
7
8
9
10
11
12

from vllm.logger import init_logger

# Module-level logger used by the partition-rule helpers in this file.
logger = init_logger(__name__)


def should_split(node: torch.fx.Node, splitting_ops: list[str]) -> bool:
    """
    Check if a node should be split for dynamo graph partition.

    It operates on the dynamo graph, so ``node.target`` can be anything.
    Only ``OpOverload`` and ``OpOverloadPacket`` targets are checked; every
    other node (non-``call_function`` nodes, plain Python callables, ...)
    returns False.

    Args:
        node: FX node from the dynamo graph.
        splitting_ops: Qualified operator names to split on. An entry may
            name a whole packet (e.g. "aten::add"), which matches every
            overload of it, or one specific overload (e.g. "aten::add.Tensor"
            or "vllm::unified_attention.default").

    Returns:
        True if ``node`` calls an operator listed in ``splitting_ops``.
    """
    if node.op != "call_function":
        return False

    target = node.target

    if isinstance(target, torch._ops.OpOverloadPacket):
        # Example: "aten::add"
        return target._qualified_op_name in splitting_ops

    if isinstance(target, torch._ops.OpOverload):
        # Example: "aten::add". Read from the owning packet rather than
        # ``target.name()``: for non-default overloads ``name()`` already
        # carries the overload suffix (e.g. "aten::add.Tensor"), which made
        # packet-level entries fail to match and produced names such as
        # "aten::add.Tensor.Tensor".
        packet_name = target._overloadpacket._qualified_op_name

        # Example: "aten::add.Tensor" or "aten::relu.default"
        op_overload_name = f"{packet_name}.{target._overloadname}"
        return op_overload_name in splitting_ops or packet_name in splitting_ops

    return False


@contextlib.contextmanager
def inductor_partition_rule_context(splitting_ops: list[str]):
    """Context manager to temporarily register Inductor partition rules.

    Registers custom partition rules for specified operators, forcing the
    Inductor scheduler to partition the graph at these operators. The rules
    are automatically restored to their previous state on exit, including
    when the managed block raises.

    Args:
        splitting_ops: List of operator names to partition on. If empty
            (or otherwise falsy), the Inductor config is left untouched.
    """
    if not splitting_ops:
        logger.debug("No partition ops provided; skipping rule registration.")
        yield
        return

    inductor_config = torch._inductor.config

    # Snapshot (as a copy) the current rules so they can be restored on exit.
    saved_splitting_ops: list[str] = list(
        inductor_config.custom_should_partition_ops
    )
    try:
        # Install a private copy so later mutation of the caller's list cannot
        # silently change the active rules while the context is open.
        inductor_config.custom_should_partition_ops = list(splitting_ops)
        logger.debug(
            "Registered inductor partition rules for %d operators", len(splitting_ops)
        )
        yield
    finally:
        # Restore the previous rules no matter how the block exits.
        inductor_config.custom_should_partition_ops = saved_splitting_ops
        logger.debug("Restored previous partition rules state.")