# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Helpers for resolving operator names and temporarily registering
Inductor graph-partition rules (used for CUDAGraph partitioning)."""

import contextlib
import logging

import torch
from torch._library.utils import lookup_op

from vllm.logger import init_logger

logger = init_logger(__name__)


15
def resolve_defined_ops(op_names: list[str]) -> list["torch._ops.OpOverload"]:
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
    """Resolve operator names to OpOverload objects.

    Skips operators that fail to resolve (e.g., operators not registered or
    model-specific operators not present in the current model).

    Note: Users should inspect the operator graph before lowering and ensure
    the specified operators are present in the final graph. Built-in PyTorch
    operators (aten::*, torch::*) may be decomposed, fused, or transformed
    during Inductor's compilation passes, so use them with caution.

    Args:
        op_names: List of operator names in PyTorch format
            (e.g., "vllm::unified_attention")

    Returns:
        List of successfully resolved operator overloads
    """
    resolved = []
    for op_name in op_names:
        try:
            resolved.append(lookup_op(op_name))
        except Exception:
            # Skip operators that don't exist (e.g., model-specific ops)
39
40
41
42
43
44
45
46
47
48
            # Do not warn for attention ops, warn for others
            # (most likely manually specified)
            from vllm.config import CompilationConfig

            logger.log(
                logging.DEBUG
                if op_name in CompilationConfig._attention_ops
                else logging.WARNING,
                "Failed to resolve operator for CUDAGraph partition: %s",
                op_name,
49
50
51
52
53
54
55
            )
            continue

    return resolved


@contextlib.contextmanager
def inductor_partition_rule_context(splitting_ops: list[str]):
    """Context manager to temporarily register Inductor partition rules.

    Sets ``torch._inductor.config.custom_should_partition_ops`` to the given
    operator names for the duration of the ``with`` block, forcing the
    Inductor scheduler to partition the graph at these operators. The
    previous value is restored on exit, including when the body raises.

    Args:
        splitting_ops: List of operator names to partition on. If empty (or
            otherwise falsy), the Inductor config is left untouched.
    """
    if not splitting_ops:
        logger.debug("No partition ops provided; skipping rule registration.")
        yield
        return

    # Import the config submodule explicitly rather than relying on it being
    # reachable as an attribute of the ``torch`` package.
    import torch._inductor.config as inductor_config

    # Snapshot (copy) the current rules so they can be restored on exit.
    saved_splitting_ops: list[str] = list(
        inductor_config.custom_should_partition_ops
    )
    try:
        # Register inside the ``try`` so the ``finally`` restore is
        # guaranteed to run if anything after registration raises. Install a
        # private copy so later mutation of the caller's list cannot silently
        # change the active rules while the context is open.
        inductor_config.custom_should_partition_ops = list(splitting_ops)
        logger.debug(
            "Registered inductor partition rules for %d operators",
            len(splitting_ops),
        )
        yield
    finally:
        # Restore the previous state.
        inductor_config.custom_should_partition_ops = saved_splitting_ops
        logger.debug("Restored previous partition rules state.")