# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from itertools import product
4

5
from vllm.config import CUDAGraphMode, VllmConfig
6
from vllm.forward_context import BatchDescriptor
7
8
9
from vllm.logger import init_logger

logger = init_logger(__name__)
10
11
12
13


class CudagraphDispatcher:
    """
    Runtime cudagraph dispatcher to dispatch keys for multiple sets of
    cudagraphs.

    The dispatcher stores two sets of dispatch keys, one for PIECEWISE and one
    for FULL cudagraph runtime mode. The keys are initialized depending on
    attention support and what cudagraph mode is set in CompilationConfig. The
    keys stored in dispatcher are the only source of truth for valid
    cudagraphs that can be dispatched at runtime.

    At runtime, the dispatch method generates the runtime cudagraph mode (FULL,
    PIECEWISE, or NONE for no cudagraph) and the valid key (batch descriptor)
    based on the input key. After dispatching (communicated via forward
    context), the cudagraph wrappers will trust the dispatch key to either
    capture or replay (if the mode matches), or pass through to the underlying
    runnable without cudagraph (if the mode does not match or mode is NONE).
    """

    def __init__(self, vllm_config: VllmConfig):
        """Set up empty key sets and validate the compilation config.

        Args:
            vllm_config: Global config; the compilation, speculative and
                scheduler sub-configs are read from it.

        Raises:
            AssertionError: If the configured cudagraph mode requires
                piecewise compilation but attention is not compiled
                piecewise.
        """
        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config

        # Query length of a uniform decode batch: one token per request, plus
        # the speculative tokens when speculative decoding is configured.
        self.uniform_decode_query_len = (
            1
            if not self.vllm_config.speculative_config
            else 1 + self.vllm_config.speculative_config.num_speculative_tokens
        )

        # Dict to store valid cudagraph dispatching keys.
        self.cudagraph_keys: dict[CUDAGraphMode, set[BatchDescriptor]] = {
            CUDAGraphMode.PIECEWISE: set(),
            CUDAGraphMode.FULL: set(),
        }

        assert (
            not self.compilation_config.cudagraph_mode.requires_piecewise_compilation()
            or self.compilation_config.is_attention_compiled_piecewise()
        ), (
            "Compilation mode should be CompilationMode.VLLM_COMPILE when "
            "cudagraph_mode piecewise cudagraphs is used, "
            "and attention should be in splitting_ops or "
            "inductor splitting should be used. "
            f"cudagraph_mode={self.compilation_config.cudagraph_mode}, "
            f"compilation_mode={self.compilation_config.mode}, "
            f"splitting_ops={self.compilation_config.splitting_ops}"
        )

        self.keys_initialized = False

        # Default to NONE until initialize_cudagraph_keys() provides the
        # resolved mode, so that helpers reading self.cudagraph_mode never hit
        # an AttributeError when they run before initialization.
        self.cudagraph_mode: CUDAGraphMode = CUDAGraphMode.NONE

    def _create_padded_batch_descriptor(
        self, num_tokens: int, uniform_decode: bool, has_lora: bool
    ) -> BatchDescriptor:
        """Build the batch descriptor for ``num_tokens`` after cudagraph padding.

        Args:
            num_tokens: Unpadded number of tokens in the batch.
            uniform_decode: Whether the batch is a uniform decode batch. Only
                honoured when the current cudagraph mode includes FULL;
                otherwise the descriptor is marked non-uniform.
            has_lora: Whether LoRA is active for the batch.

        Returns:
            A descriptor carrying the padded token count, the derived request
            count, the effective uniform flag and the LoRA flag.

        Raises:
            AssertionError: For an honoured uniform decode batch whose padded
                token count is not a multiple of the uniform decode query
                length.
        """
        max_num_seqs = self.vllm_config.scheduler_config.max_num_seqs
        uniform_decode_query_len = self.uniform_decode_query_len
        num_tokens_padded = self.vllm_config.pad_for_cudagraph(num_tokens)

        if uniform_decode and self.cudagraph_mode.has_mode(CUDAGraphMode.FULL):
            # Every request contributes exactly uniform_decode_query_len
            # tokens, so the request count follows from the padded size.
            num_reqs = num_tokens_padded // uniform_decode_query_len
            assert num_tokens_padded % uniform_decode_query_len == 0
        else:
            uniform_decode = False
            num_reqs = min(num_tokens_padded, max_num_seqs)

        return BatchDescriptor(
            num_tokens=num_tokens_padded,
            num_reqs=num_reqs,
            uniform=uniform_decode,
            has_lora=has_lora,
        )

    def add_cudagraph_key(
        self, runtime_mode: CUDAGraphMode, batch_descriptor: BatchDescriptor
    ) -> None:
        """Register ``batch_descriptor`` as a valid key for ``runtime_mode``.

        Args:
            runtime_mode: Must be ``CUDAGraphMode.PIECEWISE`` or
                ``CUDAGraphMode.FULL``.
            batch_descriptor: Key to add to that mode's key set.

        Raises:
            AssertionError: If ``runtime_mode`` is any other mode.
        """
        assert runtime_mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], (
            f"Invalid cudagraph runtime mode for keys: {runtime_mode}"
        )
        self.cudagraph_keys[runtime_mode].add(batch_descriptor)

    def initialize_cudagraph_keys(
        self, cudagraph_mode: CUDAGraphMode, uniform_decode_query_len: int
    ) -> None:
        """Populate the dispatch key sets for the resolved cudagraph mode.

        Args:
            cudagraph_mode: Cudagraph mode to dispatch for; stored on the
                dispatcher.
            uniform_decode_query_len: Query length of a uniform decode batch,
                used to select the capture sizes eligible for decode-only
                FULL cudagraphs.
        """
        # This should be called only after attention backend is initialized. So we can
        # get the correct cudagraph mode after backend support is resolved.
        self.cudagraph_mode = cudagraph_mode

        # LoRA activation cases to specialize the cuda graphs on
        if self.vllm_config.lora_config:
            if self.compilation_config.cudagraph_specialize_lora:
                lora_cases = [True, False]
            else:
                lora_cases = [True]
        else:
            lora_cases = [False]

        # Note: we create all valid keys for cudagraph here but do not
        # guarantee all keys would be used. For example, if we allow lazy
        # capturing in future PR, some keys may never be triggered.
        mixed_mode = cudagraph_mode.mixed_mode()
        if mixed_mode != CUDAGraphMode.NONE:
            for bs, has_lora in product(
                self.compilation_config.cudagraph_capture_sizes, lora_cases
            ):
                self.add_cudagraph_key(
                    mixed_mode,
                    self._create_padded_batch_descriptor(
                        bs, False, has_lora
                    ).relax_for_mixed_batch_cudagraphs(),
                )

        # if decode cudagraph mode is FULL, and we don't already have mixed
        # mode full cudagraphs then add them here.
        if (
            cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
            and cudagraph_mode.separate_routine()
        ):
            # NOTE(review): the size filter below uses the
            # `uniform_decode_query_len` argument, while
            # `_create_padded_batch_descriptor` reads
            # `self.uniform_decode_query_len` set in `__init__`; the two are
            # presumably equal -- confirm against callers.
            max_num_tokens = (
                uniform_decode_query_len
                * self.vllm_config.scheduler_config.max_num_seqs
            )
            cudagraph_capture_sizes_for_decode = [
                x
                for x in self.compilation_config.cudagraph_capture_sizes
                if x <= max_num_tokens and x >= uniform_decode_query_len
            ]
            for bs, has_lora in product(cudagraph_capture_sizes_for_decode, lora_cases):
                self.add_cudagraph_key(
                    CUDAGraphMode.FULL,
                    self._create_padded_batch_descriptor(bs, True, has_lora),
                )

        self.keys_initialized = True

    def dispatch(
        self,
        num_tokens: int,
        uniform_decode: bool,
        has_lora: bool,
        disable_full: bool = False,
    ) -> tuple[CUDAGraphMode, BatchDescriptor]:
        """
        Given the batch conditions, dispatch to a cudagraph runtime mode and
        the valid batch descriptor.

        A new batch descriptor is returned as we might dispatch a uniform batch
        to a graph that supports a more general batch (uniform to non-uniform).

        Args:
            num_tokens: Unpadded number of tokens in the batch.
            uniform_decode: Whether the batch is a uniform decode batch.
            has_lora: Whether LoRA is active for the batch.
            disable_full: If True, skip the FULL-mode lookups and only
                consider PIECEWISE keys.

        Returns:
            ``(mode, descriptor)``. ``mode`` is FULL or PIECEWISE when a
            matching key is registered, in which case ``descriptor`` is that
            (padded) key. Otherwise ``mode`` is NONE and ``descriptor`` only
            carries the unpadded ``num_tokens``.
        """
        if (
            not self.keys_initialized
            or self.cudagraph_mode == CUDAGraphMode.NONE
            or num_tokens > self.compilation_config.max_cudagraph_capture_size
        ):
            return CUDAGraphMode.NONE, BatchDescriptor(num_tokens)

        batch_desc = self._create_padded_batch_descriptor(
            num_tokens, uniform_decode, has_lora
        )
        relaxed_batch_desc = batch_desc.relax_for_mixed_batch_cudagraphs()

        if not disable_full:
            # check if key exists for full cudagraph
            if batch_desc in self.cudagraph_keys[CUDAGraphMode.FULL]:
                return CUDAGraphMode.FULL, batch_desc

            # otherwise, check if the relaxed key exists
            if relaxed_batch_desc in self.cudagraph_keys[CUDAGraphMode.FULL]:
                return CUDAGraphMode.FULL, relaxed_batch_desc

        # also check if the relaxed key exists for more "general"
        # piecewise cudagraph
        if relaxed_batch_desc in self.cudagraph_keys[CUDAGraphMode.PIECEWISE]:
            return CUDAGraphMode.PIECEWISE, relaxed_batch_desc

        # finally, just return no cudagraphs and a trivial batch descriptor
        return CUDAGraphMode.NONE, BatchDescriptor(num_tokens)