# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from typing import Any

import torch
import torch.nn as nn
from tqdm import tqdm

from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode
from vllm.distributed.parallel_state import graph_capture, is_global_first_rank
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.model_executor.offloader.base import get_offloader
from vllm.utils.math_utils import cdiv
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.attn_utils import build_slot_mappings_by_layer
from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.cp_utils import prepare_dcp_local_seq_lens
from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp
from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers
from vllm.v1.worker.gpu.model_states.interface import ModelState
from vllm.v1.worker.utils import AttentionGroup


class CudaGraphManager:
    """Captures CUDA graphs for a model and replays the FULL ones.

    The capture sizes are derived once, at construction time, from the
    compilation config. FULL graphs are stored in `self.graphs`, keyed by the
    number of tokens they were captured with, and write their outputs into
    persistent buffers (`self.hidden_states` / `self.aux_hidden_states`) that
    `run_fullgraph` slices after replay.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        use_aux_hidden_state_outputs: bool,
        device: torch.device,
    ):
        """Read the relevant config values and compute the capture sizes.

        Args:
            vllm_config: Source of the scheduler, model, parallel, speculative
                and compilation settings read below.
            use_aux_hidden_state_outputs: If True, the model is expected to
                return `(hidden_states, aux_hidden_states)` instead of a
                single tensor.
            device: Device handed to `graph_capture` during `capture`.
        """
        self.vllm_config = vllm_config
        self.scheduler_config = vllm_config.scheduler_config
        self.use_aux_hidden_state_outputs = use_aux_hidden_state_outputs
        self.device = device

        self.max_model_len = vllm_config.model_config.max_model_len
        self.max_num_reqs = self.scheduler_config.max_num_seqs
        self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
        self.dp_size = vllm_config.parallel_config.data_parallel_size

        # Query length of every request in a "uniform decode" batch: one
        # token, plus the speculative tokens when speculation is configured.
        self.uniform_decode_query_len = 1
        spec_config = vllm_config.speculative_config
        if spec_config is not None:
            self.uniform_decode_query_len += spec_config.num_speculative_tokens

        self.compilation_config = vllm_config.compilation_config
        assert self.compilation_config is not None
        self.cudagraph_mode = self.compilation_config.cudagraph_mode

        # A dedicated set of decode graphs is only used when decode batches
        # run in FULL mode through a routine separate from mixed batches.
        use_uniform_decode_cudagraph = (
            self.cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
            and self.cudagraph_mode.separate_routine()
        )
        self.cudagraph_sizes, self.uniform_decode_cudagraph_sizes = get_cudagraph_sizes(
            self.compilation_config.cudagraph_capture_sizes,
            self.max_num_reqs,
            self.max_num_tokens,
            self.cudagraph_mode,
            self.uniform_decode_query_len,
            use_uniform_decode_cudagraph,
        )

        # FULL graphs, keyed by the token count they were captured with.
        self.graphs: dict[int, torch.cuda.CUDAGraph] = {}
        # Pool handle passed to every `torch.cuda.graph` capture; left as None
        # when CUDA graphs are disabled.
        self.pool = None
        if self.cudagraph_mode != CUDAGraphMode.NONE:
            self.pool = torch.cuda.graph_pool_handle()
        # Persistent output buffers, allocated lazily by `capture_graph`.
        self.hidden_states: torch.Tensor | None = None
        self.aux_hidden_states: list[torch.Tensor] = []

    def needs_capture(self) -> bool:
        """Return True if at least one capture size is configured."""
        return bool(self.cudagraph_sizes)

    def get_cudagraph_size(
        self, num_tokens: int, uniform_decode: bool = False
    ) -> int | None:
        """Look up the capture size that serves a batch of `num_tokens`.

        The uniform-decode table is consulted only when `uniform_decode` is
        set and that table is non-empty; otherwise the general table is used.

        Returns:
            The capture size, or None if the consulted table has no entry for
            `num_tokens`.
        """
        if uniform_decode and self.uniform_decode_cudagraph_sizes:
            return self.uniform_decode_cudagraph_sizes.get(num_tokens)
        return self.cudagraph_sizes.get(num_tokens)

    def capture_graph(
        self,
        num_tokens: int,
        capture_cg_mode: CUDAGraphMode,
        model: nn.Module,
        model_state: ModelState,
        input_buffers: InputBuffers,
        block_tables: BlockTables,
        attn_groups: list[list[AttentionGroup]],
        kv_cache_config: KVCacheConfig,
        has_lora: bool = False,
        uniform_decode: bool = False,
    ) -> None:
        """Warm up the model and capture one graph for `num_tokens` tokens.

        Builds dummy inputs, runs one forward pass with CUDA graphs disabled,
        allocates the persistent output buffers on first use, then delegates
        to the FULL or PIECEWISE capture routine.

        Args:
            num_tokens: Number of tokens in the dummy batch.
            capture_cg_mode: Must be `CUDAGraphMode.PIECEWISE` or
                `CUDAGraphMode.FULL`; selects the capture routine.
            model: Module called as `model(**model_inputs)`.
            model_state: Supplies extra dummy inputs and attention metadata.
            input_buffers: Buffers whose leading `num_tokens` entries are
                used as `input_ids` and `positions`.
            block_tables: Source of the dummy block tables / slot mappings.
            attn_groups: Forwarded to `prepare_inputs_to_capture`.
            kv_cache_config: Forwarded to `prepare_inputs_to_capture`.
            has_lora: Forwarded to the capture routine.
            uniform_decode: If True, the number of dummy requests is derived
                from `uniform_decode_query_len` tokens per request.
        """
        # select and check capture function
        assert capture_cg_mode in (CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL), (
            f"Invalid capture_cudagraph_mode for capture: {capture_cg_mode}"
        )
        if capture_cg_mode == CUDAGraphMode.PIECEWISE:
            capture_fn = self._capture_piecewise_graph
        else:
            capture_fn = self._capture_full_graph

        # prepare inputs
        if uniform_decode:
            num_reqs = min(
                cdiv(num_tokens, self.uniform_decode_query_len),
                self.max_num_reqs,
            )
        else:
            num_reqs = min(num_tokens, self.max_num_reqs)

        model_inputs = {
            "input_ids": input_buffers.input_ids[:num_tokens],
            "positions": input_buffers.positions[:num_tokens],
            # NOTE: Values returned by `prepare_dummy_inputs` will override the
            # default values above.
            **model_state.prepare_dummy_inputs(num_reqs, num_tokens),
        }

        attn_metadata, slot_mappings = prepare_inputs_to_capture(
            num_reqs,
            num_tokens,
            model_state,
            input_buffers,
            block_tables,
            attn_groups,
            kv_cache_config,
        )
        num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens)

        # Warm up.
        with set_forward_context(
            attn_metadata,
            self.vllm_config,
            num_tokens=num_tokens,
            cudagraph_runtime_mode=CUDAGraphMode.NONE,
            num_tokens_across_dp=num_tokens_across_dp,
            slot_mapping=slot_mappings,
        ):
            model_output = model(**model_inputs)
            if self.use_aux_hidden_state_outputs:
                hidden_states, aux_hidden_states = model_output
            else:
                hidden_states = model_output
                aux_hidden_states = None

        # Allocate output buffers if not already done.
        # The buffers take the shape of the first warm-up output and are only
        # ever sliced with `[:num_tokens]` afterwards, so the first capture
        # must be the largest one.
        if self.hidden_states is None:
            self.hidden_states = torch.empty_like(hidden_states)
        if self.use_aux_hidden_state_outputs and not self.aux_hidden_states:
            self.aux_hidden_states = [torch.empty_like(x) for x in aux_hidden_states]

        capture_fn(
            num_tokens=num_tokens,
            num_reqs=num_reqs,
            model=model,
            model_inputs=model_inputs,
            num_tokens_across_dp=num_tokens_across_dp,
            attn_metadata=attn_metadata,
            slot_mappings=slot_mappings,
            has_lora=has_lora,
        )

    def _capture_full_graph(
        self,
        num_tokens: int,
        num_reqs: int,
        model: nn.Module,
        model_inputs: dict[str, torch.Tensor | None],
        num_tokens_across_dp: torch.Tensor,
        attn_metadata: dict[str, Any] | None,
        slot_mappings: dict[str, torch.Tensor] | None,
        has_lora: bool = False,
    ) -> None:
        """Record one forward pass into a CUDA graph stored under `num_tokens`.

        `num_reqs` and `has_lora` are not used here; they are accepted so that
        this method and `_capture_piecewise_graph` can be called with the same
        keyword arguments.
        """
        assert attn_metadata is not None
        # Capture the graph.
        assert num_tokens not in self.graphs
        graph = torch.cuda.CUDAGraph()

        # Sync offloader's copy stream before capture.
        # Ensure any pre-capture prefetches from offloader are complete.
        get_offloader().sync_prev_onload()

        with (
            set_forward_context(
                attn_metadata=attn_metadata,
                vllm_config=self.vllm_config,
                num_tokens=num_tokens,
                cudagraph_runtime_mode=CUDAGraphMode.NONE,
                num_tokens_across_dp=num_tokens_across_dp,
                slot_mapping=slot_mappings,
            ),
            torch.cuda.graph(graph, self.pool),
        ):
            model_output = model(**model_inputs)

            # Join offloader's copy stream after forward to avoid unjoined
            # stream error. The last layer's start_prefetch forks copy_stream,
            # but wait_prefetch only happens in the next forward pass.
            get_offloader().join_after_forward()

            if self.use_aux_hidden_state_outputs:
                hidden_states, aux_hidden_states = model_output
            else:
                hidden_states = model_output
                aux_hidden_states = None

            # Copy outputs to the output buffers.
            # These copies run inside the capture context, so they are part of
            # what `run_fullgraph` replays.
            assert self.hidden_states is not None
            self.hidden_states[:num_tokens] = hidden_states
            if self.use_aux_hidden_state_outputs:
                for i, aux_hidden in enumerate(aux_hidden_states):
                    self.aux_hidden_states[i][:num_tokens] = aux_hidden
        self.graphs[num_tokens] = graph

    def _capture_piecewise_graph(
        self,
        num_tokens: int,
        num_reqs: int,
        model: nn.Module,
        model_inputs: dict[str, torch.Tensor | None],
        num_tokens_across_dp: torch.Tensor,
        attn_metadata: dict[str, Any] | None,
        slot_mappings: dict[str, torch.Tensor] | None,
        has_lora: bool = False,
    ) -> None:
        """Run one forward pass in PIECEWISE mode so the graphs get captured.

        Nothing is stored in `self.graphs`. `num_reqs` and `attn_metadata`
        are not used here; the forward context is always given
        `attn_metadata=None`.
        """
        # create batch descriptor for piecewise cudagraph dispatch key
        batch_descriptor = BatchDescriptor(num_tokens=num_tokens, has_lora=has_lora)

        # Capture run - CUDAGraphWrapper inside torch.compile will auto capture.
        with set_forward_context(
            attn_metadata=None,  # piecewise no need attn_metadata
            vllm_config=self.vllm_config,
            num_tokens=num_tokens,
            cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
            num_tokens_across_dp=num_tokens_across_dp,
            batch_descriptor=batch_descriptor,
            slot_mapping=slot_mappings,
        ):
            model(**model_inputs)

    @torch.inference_mode()
    def capture(
        self,
        model: nn.Module,
        model_state: ModelState,
        input_buffers: InputBuffers,
        block_tables: BlockTables,
        attn_groups: list[list[AttentionGroup]],
        kv_cache_config: KVCacheConfig,
        has_lora: bool = False,
    ) -> None:
        """Capture every configured graph, in up to two phases.

        Phase 1 covers mixed prefill-decode batches using the mode returned
        by `cudagraph_mode.mixed_mode()`. Phase 2 captures FULL graphs for
        uniform decode batches when dedicated decode sizes exist. All
        arguments are forwarded unchanged to `capture_graph`.
        """
        common_kwargs = dict(
            device=self.device,
            capture_fn=self.capture_graph,
            model=model,
            model_state=model_state,
            input_buffers=input_buffers,
            block_tables=block_tables,
            attn_groups=attn_groups,
            kv_cache_config=kv_cache_config,
            has_lora=has_lora,
        )

        # Phase 1: Capture for mixed prefill-decode batches if needed.
        mixed_mode = self.cudagraph_mode.mixed_mode()
        if mixed_mode != CUDAGraphMode.NONE:
            capture_graphs(
                cudagraph_sizes=self.cudagraph_sizes,
                capture_cudagraph_mode=mixed_mode,
                desc=f"Capturing CUDA graphs (mixed, {mixed_mode.name})",
                uniform_decode=False,
                **common_kwargs,
            )

        # Phase 2: Capture FULL graphs for uniform decode batches if needed.
        # This is only needed if we use a separate routine for decode batches
        # and the decode_mode is FULL.
        if self.uniform_decode_cudagraph_sizes:
            capture_graphs(
                cudagraph_sizes=self.uniform_decode_cudagraph_sizes,
                capture_cudagraph_mode=CUDAGraphMode.FULL,
                desc="Capturing CUDA graphs (decode, FULL)",
                uniform_decode=True,
                **common_kwargs,
            )

    def get_cudagraph_runtime_mode(
        self, num_reqs: int, num_tokens: int, max_query_len: int
    ) -> tuple[CUDAGraphMode, int | None]:
        """Choose the CUDA graph mode and capture size for a batch.

        A batch counts as uniform decode when every request has exactly
        `uniform_decode_query_len` tokens, i.e. `max_query_len` equals that
        length and `num_tokens == max_query_len * num_reqs`.

        Returns:
            `(mode, size)`. `size` is None whenever `mode` is
            `CUDAGraphMode.NONE`.
        """
        is_uniform_decode = (max_query_len == self.uniform_decode_query_len) and (
            num_tokens == max_query_len * num_reqs
        )

        cudagraph_size = self.get_cudagraph_size(num_tokens, is_uniform_decode)
        if cudagraph_size is None:
            cudagraph_mode = CUDAGraphMode.NONE
        elif is_uniform_decode:
            cudagraph_mode = self.cudagraph_mode.decode_mode()
        else:
            cudagraph_mode = self.cudagraph_mode.mixed_mode()

        if (
            cudagraph_mode == CUDAGraphMode.FULL
            and cudagraph_size is not None
            and cudagraph_size not in self.graphs
        ):
            # If graph wasn't captured yet, fall back to eager.
            # This might happen when the dummy run is called before capture.
            cudagraph_mode = CUDAGraphMode.NONE
            cudagraph_size = None
        return cudagraph_mode, cudagraph_size

    def run_fullgraph(
        self, num_tokens: int
    ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
        """Replay the FULL graph captured for `num_tokens` and return outputs.

        Returns:
            The first `num_tokens` rows of the hidden-state buffer, paired
            with the matching slices of the auxiliary buffers when
            `use_aux_hidden_state_outputs` is set. The returned tensors are
            slices of the persistent buffers, not copies.

        Raises:
            AssertionError: If no graph was captured for `num_tokens`.
        """
        assert num_tokens in self.graphs, f"No cudagraph for {num_tokens} tokens"
        # Sync offloader before replay - needed when transitioning from
        # eager/piecewise to full cudagraph (e.g., prefill → decode).
        # The previous eager iteration's start_prefetch may have queued
        # H2D copies on copy_stream that the graph's captured events
        # cannot see. Without this, replay could overwrite static buffers
        # while those copies are still in flight.
        get_offloader().sync_prev_onload()
        self.graphs[num_tokens].replay()
        assert self.hidden_states is not None
        hidden_states = self.hidden_states[:num_tokens]
        if not self.use_aux_hidden_state_outputs:
            return hidden_states
        return hidden_states, [x[:num_tokens] for x in self.aux_hidden_states]


def get_cudagraph_sizes(
    capture_sizes: list[int] | None,
    max_num_reqs: int,
    max_num_tokens: int,
    cudagraph_mode: CUDAGraphMode,
    uniform_decode_query_len: int = 1,
    uniform_decode_cudagraph: bool = False,
) -> tuple[dict[int, int], dict[int, int]]:
    """Build the token-count -> capture-size lookup tables.

    Every token count from 1 up to the largest capture size is mapped to the
    smallest capture size that is greater than or equal to it.

    Args:
        capture_sizes: Token counts to capture graphs for; None or empty
            disables CUDA graphs.
        max_num_reqs: Maximum number of requests in a batch; bounds the
            uniform-decode table.
        max_num_tokens: Accepted for interface compatibility; the value is
            not used.
        cudagraph_mode: `CUDAGraphMode.NONE` disables CUDA graphs; any other
            mode (FULL or PIECEWISE variants) is supported.
        uniform_decode_query_len: Tokens per request in a uniform decode
            batch.
        uniform_decode_cudagraph: Whether to also build the uniform-decode
            table.

    Returns:
        `(cudagraph_sizes, uniform_decode_cudagraph_sizes)`. The second table
        is the subset of the first whose capture size lies between
        `uniform_decode_query_len` and
        `max_num_reqs * uniform_decode_query_len` inclusive; it is empty
        unless `uniform_decode_cudagraph` is set. Both are empty when CUDA
        graphs are disabled.
    """
    if not capture_sizes or cudagraph_mode == CUDAGraphMode.NONE:
        return {}, {}

    # Sweep the capture sizes in ascending order; each size claims every
    # token count not already covered by a smaller size.
    cudagraph_sizes: dict[int, int] = {}
    next_uncovered = 1
    for size in sorted(capture_sizes):
        for n in range(next_uncovered, size + 1):
            cudagraph_sizes[n] = size
        next_uncovered = max(next_uncovered, size + 1)

    uniform_decode_cudagraph_sizes: dict[int, int] = {}
    if uniform_decode_cudagraph:
        # Largest batch a uniform decode step can produce.
        max_decode_tokens = max_num_reqs * uniform_decode_query_len
        uniform_decode_cudagraph_sizes = {
            n: size
            for n, size in cudagraph_sizes.items()
            if uniform_decode_query_len <= size <= max_decode_tokens
        }
    return cudagraph_sizes, uniform_decode_cudagraph_sizes


def capture_graphs(
    cudagraph_sizes: dict[int, int],
    device: torch.device,
    capture_fn: Callable,
    capture_cudagraph_mode: CUDAGraphMode,
    desc: str = "Capturing CUDA graphs",
    **capture_kwargs,
) -> None:
    """Call `capture_fn` once per distinct capture size, largest first.

    Args:
        cudagraph_sizes: Token-count -> capture-size table; only its distinct
            values are captured.
        device: Device passed to `graph_capture`.
        capture_fn: Called as
            `capture_fn(size, capture_cudagraph_mode, **capture_kwargs)`.
        capture_cudagraph_mode: Mode forwarded to `capture_fn`.
        desc: Progress-bar label, shown only on the global first rank.
        **capture_kwargs: Extra keyword arguments forwarded to `capture_fn`.
    """
    # Capture larger graphs first.
    sizes_to_capture = sorted(set(cudagraph_sizes.values()), reverse=True)
    if is_global_first_rank():
        # Only one rank draws the progress bar.
        sizes_to_capture = tqdm(sizes_to_capture, desc=desc)

    with graph_capture(device=device):
        for size in sizes_to_capture:
            capture_fn(size, capture_cudagraph_mode, **capture_kwargs)


def prepare_inputs_to_capture(
    num_reqs: int,
    num_tokens: int,
    model_state: ModelState,
    input_buffers: InputBuffers,
    block_tables: BlockTables,
    attn_groups: list[list[AttentionGroup]],
    kv_cache_config: KVCacheConfig,
) -> tuple[dict[str, Any], dict[str, torch.Tensor]]:
    """Build dummy attention metadata and slot mappings for a capture run.

    Args:
        num_reqs: Number of dummy requests.
        num_tokens: Number of dummy tokens.
        model_state: Its `prepare_attn` produces the attention metadata.
        input_buffers: Backing buffers for the dummy input batch.
        block_tables: Supplies the dummy block tables and slot mappings, and
            the context-parallel settings (`cp_size`, `cp_rank`,
            `cp_interleave`).
        attn_groups: Forwarded to `model_state.prepare_attn`.
        kv_cache_config: Used for the per-layer slot mappings and forwarded
            to `model_state.prepare_attn`.

    Returns:
        `(attn_metadata, slot_mappings_by_layer)`.
    """
    input_batch = InputBatch.make_dummy(num_reqs, num_tokens, input_buffers)
    input_block_tables = block_tables.get_dummy_block_tables(num_reqs)
    slot_mappings = block_tables.get_dummy_slot_mappings(num_tokens)
    slot_mappings_by_layer = build_slot_mappings_by_layer(
        slot_mappings, kv_cache_config
    )

    # HACK(woosuk): Special handling for DCP.
    if block_tables.cp_size > 1:
        prepare_dcp_local_seq_lens(
            input_buffers.dcp_local_seq_lens,
            input_batch.seq_lens,
            num_reqs,
            block_tables.cp_size,
            block_tables.cp_rank,
            block_tables.cp_interleave,
        )
        input_batch.dcp_local_seq_lens = input_buffers.dcp_local_seq_lens[:num_reqs]

    # `prepare_attn` receives the raw slot mappings; the per-layer form is
    # only returned to the caller.
    attn_metadata = model_state.prepare_attn(
        input_batch,
        input_block_tables,
        slot_mappings,
        attn_groups,
        kv_cache_config,
    )
    return attn_metadata, slot_mappings_by_layer