cudagraph_utils.py 17.2 KB
Newer Older
Woosuk Kwon's avatar
Woosuk Kwon committed
1
2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from collections.abc import Callable
4
from typing import Any
Woosuk Kwon's avatar
Woosuk Kwon committed
5
6
7
8
9
10
11
12
13

import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm

from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode
from vllm.distributed.parallel_state import graph_capture, is_global_first_rank
14
from vllm.forward_context import BatchDescriptor, set_forward_context
15
from vllm.model_executor.offloader.base import get_offloader
16
from vllm.utils.math_utils import cdiv
Woosuk Kwon's avatar
Woosuk Kwon committed
17
from vllm.v1.kv_cache_interface import KVCacheConfig
18
19
20
21
from vllm.v1.worker.gpu.attn_utils import (
    build_attn_metadata,
    build_slot_mappings_by_layer,
)
Woosuk Kwon's avatar
Woosuk Kwon committed
22
from vllm.v1.worker.gpu.block_table import BlockTables
23
from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp
Woosuk Kwon's avatar
Woosuk Kwon committed
24
from vllm.v1.worker.gpu.input_batch import InputBuffers
25
from vllm.v1.worker.utils import AttentionGroup
Woosuk Kwon's avatar
Woosuk Kwon committed
26
27
28


class CudaGraphManager:
    """Capture and replay CUDA graphs for the GPU model runner.

    Two capture flavours are driven from here:

    * ``PIECEWISE``: the model is executed once under a PIECEWISE forward
      context and the ``CUDAGraphWrapper`` inside ``torch.compile`` records
      the graphs itself; nothing is stored on this object.
    * ``FULL``: the entire forward pass is recorded into a
      ``torch.cuda.CUDAGraph`` kept in ``self.graphs`` (keyed by the padded
      token count) and replayed by :meth:`run_fullgraph`. Every FULL graph
      writes its outputs into the shared persistent buffers
      ``self.hidden_states`` / ``self.aux_hidden_states``.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        uses_mrope: bool,
        use_aux_hidden_state_outputs: bool,
        device: torch.device,
    ):
        self.vllm_config = vllm_config
        self.scheduler_config = vllm_config.scheduler_config
        self.uses_mrope = uses_mrope
        self.use_aux_hidden_state_outputs = use_aux_hidden_state_outputs
        self.device = device

        # Static limits read from the model / scheduler / parallel configs.
        self.max_model_len = vllm_config.model_config.max_model_len
        self.max_num_reqs = self.scheduler_config.max_num_seqs
        self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
        self.dp_size = vllm_config.parallel_config.data_parallel_size

        # Query length of every request in a uniform decode batch: one token,
        # plus the speculative tokens when a speculative config is present.
        spec_config = vllm_config.speculative_config
        num_spec_tokens = (
            0 if spec_config is None else spec_config.num_speculative_tokens
        )
        self.uniform_decode_query_len = 1 + num_spec_tokens

        self.compilation_config = vllm_config.compilation_config
        assert self.compilation_config is not None
        self.cudagraph_mode = self.compilation_config.cudagraph_mode

        # A dedicated table of uniform-decode graphs is built only when decode
        # batches have their own routine and that routine is FULL.
        decode_is_full = self.cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
        use_uniform_decode_cudagraph = (
            decode_is_full and self.cudagraph_mode.separate_routine()
        )
        size_tables = get_cudagraph_sizes(
            capture_sizes=self.compilation_config.cudagraph_capture_sizes,
            max_num_reqs=self.max_num_reqs,
            max_num_tokens=self.max_num_tokens,
            cudagraph_mode=self.cudagraph_mode,
            uniform_decode_query_len=self.uniform_decode_query_len,
            uniform_decode_cudagraph=use_uniform_decode_cudagraph,
        )
        self.cudagraph_sizes, self.uniform_decode_cudagraph_sizes = size_tables

        # FULL graphs recorded by _capture_full_graph, keyed by token count.
        self.graphs: dict[int, torch.cuda.CUDAGraph] = {}
        # Memory pool handed to torch.cuda.graph for every FULL capture.
        self.pool = (
            torch.cuda.graph_pool_handle()
            if self.cudagraph_mode != CUDAGraphMode.NONE
            else None
        )
        # Persistent FULL-graph outputs; allocated lazily from the first
        # warm-up run in capture_graph.
        self.hidden_states: torch.Tensor | None = None
        self.aux_hidden_states: list[torch.Tensor] = []

    def needs_capture(self) -> bool:
        """Return True when at least one capture size is configured."""
        return bool(self.cudagraph_sizes)

    def get_cudagraph_size(
        self, num_tokens: int, uniform_decode: bool = False
    ) -> int | None:
        """Return the padded graph size for ``num_tokens``, or None.

        Uniform-decode batches consult the dedicated decode table when one
        exists; every other lookup uses the general table.
        """
        use_decode_table = uniform_decode and bool(
            self.uniform_decode_cudagraph_sizes
        )
        table = (
            self.uniform_decode_cudagraph_sizes
            if use_decode_table
            else self.cudagraph_sizes
        )
        return table.get(num_tokens)

    def capture_graph(
        self,
        num_tokens: int,
        capture_cg_mode: CUDAGraphMode,
        model: nn.Module,
        input_buffers: InputBuffers,
        mrope_positions: torch.Tensor | None,
        inputs_embeds: torch.Tensor | None,
        block_tables: BlockTables,
        attn_groups: list[list[AttentionGroup]],
        kv_cache_config: KVCacheConfig,
        has_lora: bool = False,
        uniform_decode: bool = False,
    ) -> None:
        """Warm up the model at ``num_tokens`` tokens, then capture a graph.

        A dummy batch is prepared in the persistent input buffers, the model
        is run once without graphs (the first such run also allocates the
        persistent output buffers), and then the PIECEWISE or FULL capture
        routine selected by ``capture_cg_mode`` is invoked.

        Args:
            num_tokens: Padded token count the graph is captured for.
            capture_cg_mode: Either ``CUDAGraphMode.PIECEWISE`` or
                ``CUDAGraphMode.FULL``.
            uniform_decode: Build the dummy batch as a uniform decode batch
                (``uniform_decode_query_len`` tokens per request).
            The remaining arguments feed input preparation and the model.
        """
        assert capture_cg_mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], (
            f"Invalid capture_cudagraph_mode for capture: {capture_cg_mode}"
        )
        capture_fn = (
            self._capture_piecewise_graph
            if capture_cg_mode == CUDAGraphMode.PIECEWISE
            else self._capture_full_graph
        )

        # Uniform decode packs uniform_decode_query_len tokens per request;
        # otherwise use one request per token. Both are capped at max_num_reqs.
        tokens_per_req = self.uniform_decode_query_len if uniform_decode else 1
        num_reqs = min(cdiv(num_tokens, tokens_per_req), self.max_num_reqs)

        input_ids = input_buffers.input_ids[:num_tokens]
        positions = input_buffers.positions[:num_tokens]
        if self.uses_mrope:
            # M-RoPE positions carry a leading axis; slice the token axis.
            assert mrope_positions is not None
            positions = mrope_positions[:, :num_tokens]
        embeds = None if inputs_embeds is None else inputs_embeds[:num_tokens]
        model_inputs = {
            "input_ids": input_ids,
            "positions": positions,
            "inputs_embeds": embeds,
        }

        attn_metadata, slot_mappings = prepare_inputs_to_capture(
            num_reqs=num_reqs,
            num_tokens=num_tokens,
            input_buffers=input_buffers,
            block_tables=block_tables,
            attn_groups=attn_groups,
            max_model_len=self.max_model_len,
            kv_cache_config=kv_cache_config,
            uniform_decode_query_len=(
                self.uniform_decode_query_len if uniform_decode else 0
            ),
        )
        num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens)

        # Warm-up run without graphs.
        with set_forward_context(
            attn_metadata=attn_metadata,
            vllm_config=self.vllm_config,
            num_tokens=num_tokens,
            cudagraph_runtime_mode=CUDAGraphMode.NONE,
            num_tokens_across_dp=num_tokens_across_dp,
            slot_mapping=slot_mappings,
        ):
            hidden_states, aux_hidden_states = self._split_model_output(
                model(**model_inputs)
            )
        self._ensure_output_buffers(hidden_states, aux_hidden_states)

        capture_fn(
            num_tokens=num_tokens,
            num_reqs=num_reqs,
            model=model,
            **model_inputs,
            num_tokens_across_dp=num_tokens_across_dp,
            attn_metadata=attn_metadata,
            slot_mappings=slot_mappings,
            has_lora=has_lora,
        )

    def _split_model_output(self, model_output: Any) -> tuple[Any, Any]:
        """Split a model output into ``(hidden_states, aux_hidden_states)``.

        When aux hidden-state outputs are disabled, the output is the hidden
        states themselves and the aux part is ``None``.
        """
        if not self.use_aux_hidden_state_outputs:
            return model_output, None
        hidden_states, aux_hidden_states = model_output
        return hidden_states, aux_hidden_states

    def _ensure_output_buffers(
        self, hidden_states: torch.Tensor, aux_hidden_states: Any
    ) -> None:
        """Allocate the persistent FULL-graph output buffers on first use.

        The buffers are sized from the first warm-up output. capture_graphs
        visits sizes largest-first, so later (smaller) captures fit into
        them.
        """
        if self.hidden_states is None:
            self.hidden_states = torch.empty_like(hidden_states)
        if self.use_aux_hidden_state_outputs and not self.aux_hidden_states:
            self.aux_hidden_states = [
                torch.empty_like(aux) for aux in aux_hidden_states
            ]

    def _capture_full_graph(
        self,
        num_tokens: int,
        num_reqs: int,
        model: nn.Module,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        inputs_embeds: torch.Tensor | None,
        num_tokens_across_dp: torch.Tensor,
        attn_metadata: dict[str, Any] | None,
        slot_mappings: dict[str, torch.Tensor] | None,
        has_lora: bool = False,
    ) -> None:
        """Record the whole forward pass into a new graph for ``num_tokens``.

        The copies of the model outputs into the persistent output buffers
        happen inside the captured region, so every replay refreshes those
        buffers. ``num_reqs`` and ``has_lora`` are accepted for signature
        parity with the piecewise routine and are not used here.
        """
        assert attn_metadata is not None
        assert num_tokens not in self.graphs
        cuda_graph = torch.cuda.CUDAGraph()

        # Make sure prefetches the offloader issued before capture have
        # completed before recording starts.
        get_offloader().sync_prev_onload()

        with (
            set_forward_context(
                attn_metadata=attn_metadata,
                vllm_config=self.vllm_config,
                num_tokens=num_tokens,
                cudagraph_runtime_mode=CUDAGraphMode.NONE,
                num_tokens_across_dp=num_tokens_across_dp,
                slot_mapping=slot_mappings,
            ),
            torch.cuda.graph(cuda_graph, self.pool),
        ):
            raw_output = model(
                input_ids=input_ids,
                positions=positions,
                inputs_embeds=inputs_embeds,
            )
            # The last layer's start_prefetch forks the offloader's copy
            # stream, but the matching wait only happens in the next forward
            # pass; join it here to avoid an unjoined-stream error.
            get_offloader().join_after_forward()

            hidden_states, aux_hidden_states = self._split_model_output(raw_output)

            # Copy the outputs into the persistent buffers (recorded in the
            # graph).
            assert self.hidden_states is not None
            self.hidden_states[:num_tokens] = hidden_states
            if self.use_aux_hidden_state_outputs:
                for idx, aux in enumerate(aux_hidden_states):
                    self.aux_hidden_states[idx][:num_tokens] = aux
        self.graphs[num_tokens] = cuda_graph

    def _capture_piecewise_graph(
        self,
        num_tokens: int,
        num_reqs: int,
        model: nn.Module,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        inputs_embeds: torch.Tensor | None,
        num_tokens_across_dp: torch.Tensor,
        attn_metadata: dict[str, Any] | None,
        slot_mappings: dict[str, torch.Tensor] | None,
        has_lora: bool = False,
    ) -> None:
        """Drive one PIECEWISE capture run for ``num_tokens`` tokens.

        The CUDAGraphWrapper inside torch.compile captures automatically when
        the model runs under a PIECEWISE forward context. Nothing is stored
        here, the ``attn_metadata`` argument is not forwarded, and
        ``num_reqs`` is unused.
        """
        # Dispatch key under which the piecewise graphs are captured.
        descriptor = BatchDescriptor(num_tokens=num_tokens, has_lora=has_lora)

        with set_forward_context(
            attn_metadata=None,  # piecewise no need attn_metadata
            vllm_config=self.vllm_config,
            num_tokens=num_tokens,
            cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
            num_tokens_across_dp=num_tokens_across_dp,
            batch_descriptor=descriptor,
            slot_mapping=slot_mappings,
        ):
            model(
                input_ids=input_ids,
                positions=positions,
                inputs_embeds=inputs_embeds,
            )

    @torch.inference_mode()
    def capture(
        self,
        model: nn.Module,
        input_buffers: InputBuffers,
        mrope_positions: torch.Tensor | None,
        inputs_embeds: torch.Tensor | None,
        block_tables: BlockTables,
        attn_groups: list[list[AttentionGroup]],
        kv_cache_config: KVCacheConfig,
        has_lora: bool = False,
    ) -> None:
        """Capture every configured CUDA graph (runs under inference mode).

        Phase 1 covers mixed prefill/decode batches with the mode's mixed
        routine and is skipped when that routine is NONE. Phase 2 captures
        FULL graphs for uniform decode batches. It only runs when a separate
        uniform-decode size table was built in ``__init__``.
        """
        shared_kwargs = {
            "device": self.device,
            "capture_fn": self.capture_graph,
            "model": model,
            "input_buffers": input_buffers,
            "mrope_positions": mrope_positions,
            "inputs_embeds": inputs_embeds,
            "block_tables": block_tables,
            "attn_groups": attn_groups,
            "kv_cache_config": kv_cache_config,
            "has_lora": has_lora,
        }

        # Each entry: (size table, capture mode, progress label, uniform decode)
        phases: list[tuple[dict[int, int], CUDAGraphMode, str, bool]] = []
        mixed_mode = self.cudagraph_mode.mixed_mode()
        if mixed_mode != CUDAGraphMode.NONE:
            phases.append(
                (
                    self.cudagraph_sizes,
                    mixed_mode,
                    f"Capturing CUDA graphs (mixed, {mixed_mode.name})",
                    False,
                )
            )
        if self.uniform_decode_cudagraph_sizes:
            phases.append(
                (
                    self.uniform_decode_cudagraph_sizes,
                    CUDAGraphMode.FULL,
                    "Capturing CUDA graphs (decode, FULL)",
                    True,
                )
            )

        for sizes, mode, desc, uniform_decode in phases:
            capture_graphs(
                cudagraph_sizes=sizes,
                capture_cudagraph_mode=mode,
                desc=desc,
                uniform_decode=uniform_decode,
                **shared_kwargs,
            )

    def get_cudagraph_runtime_mode(
        self, num_reqs: int, num_tokens: int, max_query_len: int
    ) -> tuple[CUDAGraphMode, int | None]:
        """Choose the graph mode and padded size for a batch.

        A batch counts as uniform decode when every request has exactly
        ``uniform_decode_query_len`` query tokens. A FULL choice is downgraded
        to ``(NONE, None)`` when its graph has not been captured yet.

        Returns:
            ``(cudagraph_mode, cudagraph_size)``. ``cudagraph_size`` is
            ``None`` whenever no graph size covers ``num_tokens``.
        """
        is_uniform_decode = (
            max_query_len == self.uniform_decode_query_len
            and num_tokens == max_query_len * num_reqs
        )
        cudagraph_size = self.get_cudagraph_size(num_tokens, is_uniform_decode)
        if cudagraph_size is None:
            return CUDAGraphMode.NONE, None

        if is_uniform_decode:
            cudagraph_mode = self.cudagraph_mode.decode_mode()
        else:
            cudagraph_mode = self.cudagraph_mode.mixed_mode()
        # NOTE(review): if the selected routine is itself NONE (e.g.
        # mixed_mode() is NONE), the padded size is still returned alongside
        # NONE. Confirm callers do not pad eager batches unnecessarily.

        if cudagraph_mode == CUDAGraphMode.FULL and cudagraph_size not in self.graphs:
            # Not captured yet (this might happen when the dummy run is
            # called before capture): fall back to eager.
            return CUDAGraphMode.NONE, None
        return cudagraph_mode, cudagraph_size

    def run_fullgraph(
        self, num_tokens: int
    ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
        """Replay the FULL graph for ``num_tokens`` and return its outputs.

        Returns the first ``num_tokens`` rows of the persistent hidden-state
        buffer. When aux outputs are enabled, also returns the matching rows
        of each aux buffer.
        """
        assert num_tokens in self.graphs, f"No cudagraph for {num_tokens} tokens"
        # Sync the offloader before replay. This is needed when transitioning
        # from eager/piecewise to full cudagraph (e.g., prefill -> decode).
        # The previous eager iteration's start_prefetch may have queued H2D
        # copies on copy_stream that the graph's captured events cannot see.
        # Without this, replay could overwrite static buffers while those
        # copies are still in flight.
        get_offloader().sync_prev_onload()
        self.graphs[num_tokens].replay()

        assert self.hidden_states is not None
        hidden_out = self.hidden_states[:num_tokens]
        if self.use_aux_hidden_state_outputs:
            return hidden_out, [buf[:num_tokens] for buf in self.aux_hidden_states]
        return hidden_out


def get_cudagraph_sizes(
    capture_sizes: list[int] | None,
    max_num_reqs: int,
    max_num_tokens: int,
    cudagraph_mode: CUDAGraphMode,
    uniform_decode_query_len: int = 1,
    uniform_decode_cudagraph: bool = False,
) -> tuple[dict[int, int], dict[int, int]]:
    """Build the token-count -> padded-graph-size lookup tables.

    Every token count ``n`` in ``1..max(capture_sizes)`` is mapped to the
    smallest capture size that is ``>= n``.

    Args:
        capture_sizes: Candidate graph sizes in any order; duplicates are
            harmless. ``None`` or empty disables graphs.
        max_num_reqs: Maximum number of requests per batch; bounds the
            uniform-decode table.
        max_num_tokens: Accepted for interface compatibility; it is not used
            to filter the general table.
        cudagraph_mode: ``CUDAGraphMode.NONE`` disables graphs entirely.
        uniform_decode_query_len: Query tokens per request in a uniform
            decode batch.
        uniform_decode_cudagraph: Whether to build the uniform-decode table.

    Returns:
        ``(cudagraph_sizes, uniform_decode_cudagraph_sizes)``. The second
        table holds the entries of the first whose padded size lies in
        ``[uniform_decode_query_len, max_num_reqs * uniform_decode_query_len]``.
        It is empty unless ``uniform_decode_cudagraph`` is set.
    """
    # Support both FULL and PIECEWISE cudagraph modes.
    if cudagraph_mode == CUDAGraphMode.NONE or not capture_sizes:
        return {}, {}

    # Single ascending pass: every token count not yet covered by a smaller
    # capture size is padded up to the current one. This yields the same
    # "smallest size >= n" mapping as scanning all sizes for every n, but
    # without the nested loop.
    cudagraph_sizes: dict[int, int] = {}
    next_uncovered = 1
    for size in sorted(capture_sizes):
        for num_tokens in range(next_uncovered, size + 1):
            cudagraph_sizes[num_tokens] = size
        next_uncovered = max(next_uncovered, size + 1)

    uniform_decode_cudagraph_sizes: dict[int, int] = {}
    if uniform_decode_cudagraph:
        # max_num_reqs requests of uniform_decode_query_len tokens each.
        # Padded sizes above that, or below one request's query length, are
        # left out of the decode table.
        max_decode_tokens = max_num_reqs * uniform_decode_query_len
        uniform_decode_cudagraph_sizes = {
            n: size
            for n, size in cudagraph_sizes.items()
            if uniform_decode_query_len <= size <= max_decode_tokens
        }
    return cudagraph_sizes, uniform_decode_cudagraph_sizes


def capture_graphs(
    cudagraph_sizes: dict[int, int],
    device: torch.device,
    capture_fn: Callable,
    capture_cudagraph_mode: CUDAGraphMode,
    desc: str = "Capturing CUDA graphs",
    **capture_kwargs,
) -> None:
    """Capture one graph per distinct padded size in ``cudagraph_sizes``.

    Args:
        cudagraph_sizes: Token-count -> padded-size table; only its distinct
            values are captured.
        device: Device passed to ``graph_capture``.
        capture_fn: Called as
            ``capture_fn(size, capture_cudagraph_mode, **capture_kwargs)``
            once per distinct size.
        capture_cudagraph_mode: Mode forwarded to ``capture_fn``.
        desc: Progress-bar label; the bar is shown on the global first rank
            only.
    """
    distinct_sizes = set(cudagraph_sizes.values())
    # Capture larger graphs first.
    ordered_sizes = sorted(distinct_sizes, reverse=True)
    size_iter = (
        tqdm(ordered_sizes, desc=desc) if is_global_first_rank() else ordered_sizes
    )

    with graph_capture(device=device):
        for padded_size in size_iter:
            capture_fn(padded_size, capture_cudagraph_mode, **capture_kwargs)


def prepare_inputs_to_capture(
    num_reqs: int,
    num_tokens: int,
    input_buffers: InputBuffers,
    block_tables: BlockTables,
    attn_groups: list[list[AttentionGroup]],
    max_model_len: int,
    kv_cache_config: KVCacheConfig,
    uniform_decode_query_len: int = 0,
) -> tuple[dict[str, Any], dict[str, torch.Tensor]]:
    """Write a dummy batch into the persistent input buffers for capture.

    ``num_tokens`` tokens are split over ``num_reqs`` requests. Each request
    gets ``num_tokens_per_req`` tokens, and the last request's span is
    stretched or truncated so the batch ends exactly at ``num_tokens``.

    Args:
        num_reqs: Number of dummy requests in the batch.
        num_tokens: Total number of tokens in the batch.
        input_buffers: Persistent input buffers. ``query_start_loc``,
            ``seq_lens`` and ``dcp_local_seq_lens`` are overwritten.
        block_tables: Source of the per-request block tables and of the slot
            mappings for the first ``num_tokens`` tokens.
        attn_groups: Attention groups passed to ``build_attn_metadata``.
        max_model_len: Reported as ``max_seq_len`` in the metadata.
        kv_cache_config: KV cache configuration used to build the metadata
            and the per-layer slot mappings.
        uniform_decode_query_len: If > 0, the per-request query length of a
            uniform decode batch; otherwise ``num_tokens // num_reqs`` is
            used.

    Returns:
        ``(attn_metadata, slot_mappings_by_layer)``.
    """
    if uniform_decode_query_len > 0:
        num_tokens_per_req = uniform_decode_query_len
    else:
        num_tokens_per_req = num_tokens // num_reqs

    # Requests 0..num_reqs-2 get num_tokens_per_req tokens each; the last one
    # gets the remainder. In the non-uniform case that remainder can exceed
    # num_tokens_per_req (e.g. 600 tokens over 256 requests -> 2 tokens each,
    # last request 90). max_query_len must cover it to stay consistent with
    # query_start_loc below.
    last_query_len = num_tokens - (num_reqs - 1) * num_tokens_per_req
    max_query_len = max(num_tokens_per_req, last_query_len)

    query_start_loc_np = np.arange(num_reqs + 1, dtype=np.int32) * num_tokens_per_req
    query_start_loc_np[-1] = num_tokens
    query_start_loc_cpu = torch.from_numpy(query_start_loc_np)
    input_buffers.query_start_loc[: num_reqs + 1] = query_start_loc_cpu
    # Entries past the last request repeat num_tokens (zero-length requests).
    input_buffers.query_start_loc[num_reqs + 1 :] = num_tokens
    query_start_loc = input_buffers.query_start_loc[: num_reqs + 1]

    # HACK(woosuk): For faster warmup, we set seq_lens (GPU) to num_tokens
    # rather than max_model_len.
    input_buffers.seq_lens[:num_reqs] = num_tokens
    input_buffers.seq_lens[num_reqs:] = 0

    input_buffers.dcp_local_seq_lens[:num_reqs] = num_tokens
    input_buffers.dcp_local_seq_lens[num_reqs:] = 0

    input_block_tables = [x[:num_reqs] for x in block_tables.input_block_tables]
    slot_mappings = block_tables.slot_mappings[:, :num_tokens]
    slot_mappings_by_layer = build_slot_mappings_by_layer(
        slot_mappings, kv_cache_config
    )

    attn_metadata = build_attn_metadata(
        attn_groups=attn_groups,
        num_reqs=num_reqs,
        num_tokens=num_tokens,
        query_start_loc_gpu=query_start_loc,
        query_start_loc_cpu=query_start_loc_cpu,
        max_query_len=max_query_len,
        seq_lens=input_buffers.seq_lens,
        max_seq_len=max_model_len,
        block_tables=input_block_tables,
        slot_mappings=slot_mappings,
        kv_cache_config=kv_cache_config,
        dcp_local_seq_lens=input_buffers.dcp_local_seq_lens,
    )
    return attn_metadata, slot_mappings_by_layer