cudagraph_utils.py 16.5 KB
Newer Older
Woosuk Kwon's avatar
Woosuk Kwon committed
1
2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from collections.abc import Callable
4
from typing import Any
Woosuk Kwon's avatar
Woosuk Kwon committed
5
6
7
8
9
10
11
12
13

import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm

from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode
from vllm.distributed.parallel_state import graph_capture, is_global_first_rank
14
from vllm.forward_context import BatchDescriptor, set_forward_context
15
from vllm.model_executor.offloader.base import get_offloader
16
from vllm.utils.math_utils import cdiv
Woosuk Kwon's avatar
Woosuk Kwon committed
17
from vllm.v1.kv_cache_interface import KVCacheConfig
18
19
20
21
from vllm.v1.worker.gpu.attn_utils import (
    build_attn_metadata,
    build_slot_mappings_by_layer,
)
Woosuk Kwon's avatar
Woosuk Kwon committed
22
from vllm.v1.worker.gpu.block_table import BlockTables
23
from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp
Woosuk Kwon's avatar
Woosuk Kwon committed
24
from vllm.v1.worker.gpu.input_batch import InputBuffers
25
from vllm.v1.worker.gpu.model_states.interface import ModelState
26
from vllm.v1.worker.utils import AttentionGroup
Woosuk Kwon's avatar
Woosuk Kwon committed
27
28
29


class CudaGraphManager:
    """Captures CUDA graphs for the model and replays the FULL ones.

    Two kinds of graphs are handled:

    * PIECEWISE graphs are captured by running the model under a PIECEWISE
      forward context; the CUDAGraphWrapper inside torch.compile records them,
      so nothing is stored on this object.
    * FULL graphs are captured with ``torch.cuda.graph`` into ``self.graphs``
      (keyed by padded token count) and replayed by ``run_fullgraph``. While
      captured they copy their outputs into the persistent ``hidden_states`` /
      ``aux_hidden_states`` buffers owned by this manager.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        use_aux_hidden_state_outputs: bool,
        device: torch.device,
    ):
        """Read capture settings from ``vllm_config`` and set up bookkeeping.

        Args:
            vllm_config: Source of the scheduler, model, parallel, speculative
                and compilation settings used below.
            use_aux_hidden_state_outputs: If True the model returns
                ``(hidden_states, aux_hidden_states)``; otherwise it returns
                the hidden-states tensor alone.
            device: Device passed to ``graph_capture`` during ``capture``.
        """
        self.vllm_config = vllm_config
        self.scheduler_config = vllm_config.scheduler_config
        self.use_aux_hidden_state_outputs = use_aux_hidden_state_outputs
        self.device = device

        sched = self.scheduler_config
        self.max_model_len = vllm_config.model_config.max_model_len
        self.max_num_reqs = sched.max_num_seqs
        self.max_num_tokens = sched.max_num_batched_tokens
        self.dp_size = vllm_config.parallel_config.data_parallel_size

        # A uniform-decode step schedules one token per request plus any
        # speculative draft tokens.
        spec_config = vllm_config.speculative_config
        if spec_config is None:
            self.uniform_decode_query_len = 1
        else:
            self.uniform_decode_query_len = 1 + spec_config.num_speculative_tokens

        self.compilation_config = vllm_config.compilation_config
        assert self.compilation_config is not None
        self.cudagraph_mode = self.compilation_config.cudagraph_mode

        # Dedicated uniform-decode graphs exist only when decode batches take
        # a separate routine AND that routine runs FULL graphs.
        separate_full_decode = (
            self.cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
            and self.cudagraph_mode.separate_routine()
        )
        size_tables = get_cudagraph_sizes(
            self.compilation_config.cudagraph_capture_sizes,
            self.max_num_reqs,
            self.max_num_tokens,
            self.cudagraph_mode,
            self.uniform_decode_query_len,
            separate_full_decode,
        )
        self.cudagraph_sizes, self.uniform_decode_cudagraph_sizes = size_tables

        # FULL graphs, keyed by the padded token count they were captured at.
        self.graphs: dict[int, torch.cuda.CUDAGraph] = {}
        # Shared memory pool for all captured graphs (only when graphs are on).
        self.pool = (
            torch.cuda.graph_pool_handle()
            if self.cudagraph_mode != CUDAGraphMode.NONE
            else None
        )
        # Persistent output buffers written by FULL graphs; allocated lazily
        # from the first warm-up run in ``capture_graph``.
        self.hidden_states: torch.Tensor | None = None
        self.aux_hidden_states: list[torch.Tensor] = []

    def needs_capture(self) -> bool:
        """Return True when at least one padded batch size is configured."""
        return bool(self.cudagraph_sizes)

    def get_cudagraph_size(
        self, num_tokens: int, uniform_decode: bool = False
    ) -> int | None:
        """Return the padded graph size for ``num_tokens``, or None if none fits.

        Uniform-decode batches consult the dedicated decode table when it is
        non-empty; every other case uses the general table.
        """
        use_decode_table = uniform_decode and bool(
            self.uniform_decode_cudagraph_sizes
        )
        table = (
            self.uniform_decode_cudagraph_sizes
            if use_decode_table
            else self.cudagraph_sizes
        )
        return table.get(num_tokens)

    def _split_model_output(
        self, model_output: Any
    ) -> tuple[torch.Tensor, list[torch.Tensor] | None]:
        """Split a model output into ``(hidden_states, aux_hidden_states)``.

        With auxiliary outputs enabled the model output is unpacked as a
        2-item sequence; otherwise it is the hidden states and the auxiliary
        part is None.
        """
        if not self.use_aux_hidden_state_outputs:
            return model_output, None
        main_states, extra_states = model_output
        return main_states, extra_states

    def capture_graph(
        self,
        num_tokens: int,
        capture_cg_mode: CUDAGraphMode,
        model: nn.Module,
        model_state: ModelState,
        input_buffers: InputBuffers,
        block_tables: BlockTables,
        attn_groups: list[list[AttentionGroup]],
        kv_cache_config: KVCacheConfig,
        has_lora: bool = False,
        uniform_decode: bool = False,
    ) -> None:
        """Warm up on a dummy ``num_tokens`` batch, then capture it.

        Args:
            num_tokens: Padded batch size, in tokens, to capture.
            capture_cg_mode: ``CUDAGraphMode.PIECEWISE`` or
                ``CUDAGraphMode.FULL``; anything else fails the assertion.
            model, model_state, input_buffers, block_tables, attn_groups,
            kv_cache_config: Used to build the dummy inputs and the attention
                metadata for the capture.
            has_lora: Forwarded to the capture routine.
            uniform_decode: Lay the dummy batch out as requests of
                ``uniform_decode_query_len`` tokens each.

        The eager warm-up also allocates the persistent output buffers the
        first time it runs, so the first call decides their size.
        """
        assert capture_cg_mode in (CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL), (
            f"Invalid capture_cudagraph_mode for capture: {capture_cg_mode}"
        )
        capture_fn = (
            self._capture_piecewise_graph
            if capture_cg_mode == CUDAGraphMode.PIECEWISE
            else self._capture_full_graph
        )

        # Number of dummy requests: one per decode query (uniform decode) or
        # one per token otherwise, capped by the scheduler's request limit.
        if uniform_decode:
            reqs_wanted = cdiv(num_tokens, self.uniform_decode_query_len)
        else:
            reqs_wanted = num_tokens
        num_reqs = min(reqs_wanted, self.max_num_reqs)

        model_inputs = {
            "input_ids": input_buffers.input_ids[:num_tokens],
            "positions": input_buffers.positions[:num_tokens],
            # NOTE: entries returned by `prepare_dummy_inputs` override the
            # defaults above because they are unpacked last.
            **model_state.prepare_dummy_inputs(num_reqs, num_tokens),
        }

        decode_query_len = self.uniform_decode_query_len if uniform_decode else 0
        attn_metadata, slot_mappings = prepare_inputs_to_capture(
            num_reqs,
            num_tokens,
            input_buffers,
            block_tables,
            attn_groups,
            self.max_model_len,
            kv_cache_config,
            uniform_decode_query_len=decode_query_len,
        )
        num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens)

        # Eager warm-up run (no graph) on the same inputs.
        with set_forward_context(
            attn_metadata,
            self.vllm_config,
            num_tokens=num_tokens,
            cudagraph_runtime_mode=CUDAGraphMode.NONE,
            num_tokens_across_dp=num_tokens_across_dp,
            slot_mapping=slot_mappings,
        ):
            warmup_output = model(**model_inputs)
            hidden_states, aux_hidden_states = self._split_model_output(
                warmup_output
            )

        # Size the persistent output buffers from the first warm-up only.
        if self.hidden_states is None:
            self.hidden_states = torch.empty_like(hidden_states)
        if self.use_aux_hidden_state_outputs and not self.aux_hidden_states:
            self.aux_hidden_states = [
                torch.empty_like(aux) for aux in aux_hidden_states
            ]

        capture_fn(
            num_tokens=num_tokens,
            num_reqs=num_reqs,
            model=model,
            model_inputs=model_inputs,
            num_tokens_across_dp=num_tokens_across_dp,
            attn_metadata=attn_metadata,
            slot_mappings=slot_mappings,
            has_lora=has_lora,
        )

    def _capture_full_graph(
        self,
        num_tokens: int,
        num_reqs: int,
        model: nn.Module,
        model_inputs: dict[str, torch.Tensor | None],
        num_tokens_across_dp: torch.Tensor,
        attn_metadata: dict[str, Any] | None,
        slot_mappings: dict[str, torch.Tensor] | None,
        has_lora: bool = False,
    ) -> None:
        """Record a FULL graph for ``num_tokens`` into ``self.graphs``.

        ``num_reqs`` and ``has_lora`` are accepted for signature parity with
        ``_capture_piecewise_graph`` and are not used here.
        """
        assert attn_metadata is not None
        assert num_tokens not in self.graphs
        cuda_graph = torch.cuda.CUDAGraph()

        # Make sure pre-capture offloader prefetches on its copy stream have
        # completed before recording starts.
        get_offloader().sync_prev_onload()

        with set_forward_context(
            attn_metadata=attn_metadata,
            vllm_config=self.vllm_config,
            num_tokens=num_tokens,
            cudagraph_runtime_mode=CUDAGraphMode.NONE,
            num_tokens_across_dp=num_tokens_across_dp,
            slot_mapping=slot_mappings,
        ):
            with torch.cuda.graph(cuda_graph, self.pool):
                captured_output = model(**model_inputs)

                # The last layer's start_prefetch forks the offloader copy
                # stream, but the matching wait only happens in the next
                # forward; join it now to avoid an unjoined-stream error.
                get_offloader().join_after_forward()

                hidden_states, aux_hidden_states = self._split_model_output(
                    captured_output
                )

                # These copies are recorded into the graph, so every replay
                # refreshes the persistent output buffers.
                assert self.hidden_states is not None
                self.hidden_states[:num_tokens] = hidden_states
                if self.use_aux_hidden_state_outputs:
                    for idx, aux in enumerate(aux_hidden_states):
                        self.aux_hidden_states[idx][:num_tokens] = aux
        self.graphs[num_tokens] = cuda_graph

    def _capture_piecewise_graph(
        self,
        num_tokens: int,
        num_reqs: int,
        model: nn.Module,
        model_inputs: dict[str, torch.Tensor | None],
        num_tokens_across_dp: torch.Tensor,
        attn_metadata: dict[str, Any] | None,
        slot_mappings: dict[str, torch.Tensor] | None,
        has_lora: bool = False,
    ) -> None:
        """Run the model once under a PIECEWISE context to trigger capture.

        The CUDAGraphWrapper inside torch.compile records the pieces; it is
        keyed by a ``BatchDescriptor`` built from ``num_tokens``/``has_lora``.
        Attention metadata is not passed for piecewise capture.
        """
        descriptor = BatchDescriptor(num_tokens=num_tokens, has_lora=has_lora)
        with set_forward_context(
            attn_metadata=None,
            vllm_config=self.vllm_config,
            num_tokens=num_tokens,
            cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
            num_tokens_across_dp=num_tokens_across_dp,
            batch_descriptor=descriptor,
            slot_mapping=slot_mappings,
        ):
            model(**model_inputs)

    @torch.inference_mode()
    def capture(
        self,
        model: nn.Module,
        model_state: ModelState,
        input_buffers: InputBuffers,
        block_tables: BlockTables,
        attn_groups: list[list[AttentionGroup]],
        kv_cache_config: KVCacheConfig,
        has_lora: bool = False,
    ) -> None:
        """Capture every configured graph, mixed batches first, then decode."""
        shared_kwargs = {
            "device": self.device,
            "capture_fn": self.capture_graph,
            "model": model,
            "model_state": model_state,
            "input_buffers": input_buffers,
            "block_tables": block_tables,
            "attn_groups": attn_groups,
            "kv_cache_config": kv_cache_config,
            "has_lora": has_lora,
        }

        # Each phase: (size table, capture mode, progress label, uniform_decode).
        phases: list[tuple[dict[int, int], CUDAGraphMode, str, bool]] = []

        # Phase 1: mixed prefill-decode batches, in the mixed mode (if any).
        mixed_mode = self.cudagraph_mode.mixed_mode()
        if mixed_mode != CUDAGraphMode.NONE:
            phases.append(
                (
                    self.cudagraph_sizes,
                    mixed_mode,
                    f"Capturing CUDA graphs (mixed, {mixed_mode.name})",
                    False,
                )
            )

        # Phase 2: dedicated FULL graphs for uniform-decode batches; the table
        # is non-empty only for a separate decode routine in FULL mode.
        if self.uniform_decode_cudagraph_sizes:
            phases.append(
                (
                    self.uniform_decode_cudagraph_sizes,
                    CUDAGraphMode.FULL,
                    "Capturing CUDA graphs (decode, FULL)",
                    True,
                )
            )

        for size_table, phase_mode, label, is_decode in phases:
            capture_graphs(
                cudagraph_sizes=size_table,
                capture_cudagraph_mode=phase_mode,
                desc=label,
                uniform_decode=is_decode,
                **shared_kwargs,
            )

    def get_cudagraph_runtime_mode(
        self, num_reqs: int, num_tokens: int, max_query_len: int
    ) -> tuple[CUDAGraphMode, int | None]:
        """Choose the graph mode and padded size for a batch.

        A batch is uniform decode when every request has exactly
        ``uniform_decode_query_len`` tokens. Returns ``(NONE, None)`` when no
        padded size fits, or when the chosen mode is FULL but that graph has
        not been captured yet (e.g. a dummy run issued before capture).
        """
        is_uniform_decode = (
            max_query_len == self.uniform_decode_query_len
            and num_tokens == max_query_len * num_reqs
        )

        cudagraph_size = self.get_cudagraph_size(num_tokens, is_uniform_decode)
        if cudagraph_size is None:
            return CUDAGraphMode.NONE, None

        if is_uniform_decode:
            cudagraph_mode = self.cudagraph_mode.decode_mode()
        else:
            cudagraph_mode = self.cudagraph_mode.mixed_mode()

        if cudagraph_mode == CUDAGraphMode.FULL and cudagraph_size not in self.graphs:
            # Not captured yet: fall back to eager execution.
            return CUDAGraphMode.NONE, None
        return cudagraph_mode, cudagraph_size

    def run_fullgraph(
        self, num_tokens: int
    ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
        """Replay the FULL graph for ``num_tokens`` and return its outputs.

        The returned tensors are slices of the persistent output buffers, so
        they alias memory that the next replay overwrites.
        """
        assert num_tokens in self.graphs, f"No cudagraph for {num_tokens} tokens"

        # A previous eager/piecewise step (e.g. prefill before decode) may
        # have queued H2D copies on the offloader copy stream that the graph's
        # captured events cannot see; wait for them so the replay does not
        # race those copies on the static buffers.
        get_offloader().sync_prev_onload()
        self.graphs[num_tokens].replay()

        assert self.hidden_states is not None
        hidden_states = self.hidden_states[:num_tokens]
        if self.use_aux_hidden_state_outputs:
            return hidden_states, [buf[:num_tokens] for buf in self.aux_hidden_states]
        return hidden_states


def get_cudagraph_sizes(
    capture_sizes: list[int] | None,
    max_num_reqs: int,
    max_num_tokens: int,
    cudagraph_mode: CUDAGraphMode,
    uniform_decode_query_len: int = 1,
    uniform_decode_cudagraph: bool = False,
) -> tuple[dict[int, int], dict[int, int]]:
    """Map batch token counts to the padded CUDA graph size they should use.

    Args:
        capture_sizes: Graph sizes (in tokens) to capture; order and
            duplicates do not matter.
        max_num_reqs: Maximum number of requests in one batch.
        max_num_tokens: Maximum number of tokens in one batch; caps the
            uniform-decode sizes.
        cudagraph_mode: Both FULL and PIECEWISE modes are supported; NONE
            disables graphs entirely.
        uniform_decode_query_len: Tokens per request in a uniform-decode batch.
        uniform_decode_cudagraph: Whether to build the uniform-decode table.

    Returns:
        ``(cudagraph_sizes, uniform_decode_cudagraph_sizes)``.

        * ``cudagraph_sizes`` maps every token count ``1..max(capture_sizes)``
          to the smallest capture size that is ``>=`` it.
        * ``uniform_decode_cudagraph_sizes`` is the subset of that mapping
          whose padded size holds at least one full decode request and at
          most ``min(max_num_reqs * uniform_decode_query_len,
          max_num_tokens)`` tokens. It is empty unless
          ``uniform_decode_cudagraph`` is True.

        Both are empty when ``cudagraph_mode`` is NONE or no sizes are given.
    """
    if cudagraph_mode == CUDAGraphMode.NONE or not capture_sizes:
        return {}, {}

    # Single ascending pass: each size covers the token counts between the
    # previous (smaller) size and itself. Duplicates contribute nothing new.
    cudagraph_sizes: dict[int, int] = {}
    next_unmapped = 1
    for size in sorted(capture_sizes):
        for n in range(next_unmapped, size + 1):
            cudagraph_sizes[n] = size
        next_unmapped = max(next_unmapped, size + 1)

    uniform_decode_cudagraph_sizes: dict[int, int] = {}
    if uniform_decode_cudagraph:
        # A uniform-decode batch has at most max_num_reqs requests of
        # uniform_decode_query_len tokens each, and can never exceed the
        # per-batch token limit either.
        max_decode_tokens = min(
            max_num_reqs * uniform_decode_query_len, max_num_tokens
        )
        uniform_decode_cudagraph_sizes = {
            n: padded
            for n, padded in cudagraph_sizes.items()
            if uniform_decode_query_len <= padded <= max_decode_tokens
        }
    return cudagraph_sizes, uniform_decode_cudagraph_sizes


def capture_graphs(
    cudagraph_sizes: dict[int, int],
    device: torch.device,
    capture_fn: Callable,
    capture_cudagraph_mode: CUDAGraphMode,
    desc: str = "Capturing CUDA graphs",
    **capture_kwargs,
) -> None:
    """Call ``capture_fn`` once per distinct padded size, largest first.

    Args:
        cudagraph_sizes: Token-count -> padded-size table; only its distinct
            values are captured.
        device: Device handed to ``graph_capture``.
        capture_fn: Invoked as
            ``capture_fn(size, capture_cudagraph_mode, **capture_kwargs)``.
        capture_cudagraph_mode: Mode forwarded to ``capture_fn``.
        desc: Progress-bar label (progress is shown on the global first rank
            only).
        **capture_kwargs: Extra keyword arguments forwarded to ``capture_fn``.

    Larger graphs go first: ``CudaGraphManager.capture_graph`` sizes its
    persistent output buffers from the first call it sees.
    """
    distinct_sizes = sorted({*cudagraph_sizes.values()}, reverse=True)
    iterable = (
        tqdm(distinct_sizes, desc=desc) if is_global_first_rank() else distinct_sizes
    )
    with graph_capture(device=device):
        for padded_size in iterable:
            capture_fn(padded_size, capture_cudagraph_mode, **capture_kwargs)


def prepare_inputs_to_capture(
    num_reqs: int,
    num_tokens: int,
    input_buffers: InputBuffers,
    block_tables: BlockTables,
    attn_groups: list[list[AttentionGroup]],
    max_model_len: int,
    kv_cache_config: KVCacheConfig,
    uniform_decode_query_len: int = 0,
) -> tuple[dict[str, Any], dict[str, torch.Tensor]]:
    """Fill ``input_buffers`` with a dummy batch and build its attention metadata.

    Requests ``0 .. num_reqs - 2`` each get ``num_tokens_per_req`` tokens and
    the last request gets whatever remains, so the batch totals exactly
    ``num_tokens`` tokens. ``num_tokens_per_req`` is
    ``uniform_decode_query_len`` when that is positive (uniform decode) and
    ``num_tokens // num_reqs`` otherwise.

    Args:
        num_reqs: Number of dummy requests.
        num_tokens: Total number of tokens in the dummy batch.
        input_buffers: Persistent buffers whose ``query_start_loc``,
            ``seq_lens`` and ``dcp_local_seq_lens`` are overwritten here.
        block_tables: Source of the per-group block tables and slot mappings.
        attn_groups: Attention groups passed to ``build_attn_metadata``.
        max_model_len: Reported as ``max_seq_len`` in the metadata.
        kv_cache_config: KV-cache layout used to build per-layer slot mappings.
        uniform_decode_query_len: Per-request query length for uniform-decode
            batches; 0 means "not uniform decode".

    Returns:
        ``(attn_metadata, slot_mappings_by_layer)``.
    """
    if uniform_decode_query_len > 0:
        num_tokens_per_req = uniform_decode_query_len
    else:
        num_tokens_per_req = num_tokens // num_reqs

    query_start_loc_np = np.arange(num_reqs + 1, dtype=np.int32) * num_tokens_per_req
    query_start_loc_np[-1] = num_tokens
    # The last request absorbs the remainder, so it can be longer than
    # num_tokens_per_req (e.g. 264 tokens over 256 requests -> last has 9).
    # Report the true maximum so the metadata stays self-consistent.
    if num_reqs > 0:
        max_query_len = int(np.diff(query_start_loc_np).max())
    else:
        max_query_len = num_tokens_per_req

    query_start_loc_cpu = torch.from_numpy(query_start_loc_np)
    input_buffers.query_start_loc[: num_reqs + 1] = query_start_loc_cpu
    # Entries past the last request are zero-length.
    input_buffers.query_start_loc[num_reqs + 1 :] = num_tokens
    query_start_loc = input_buffers.query_start_loc[: num_reqs + 1]

    # HACK(woosuk): For faster warmup, we set seq_lens (GPU) to num_tokens
    # rather than max_model_len.
    input_buffers.seq_lens[:num_reqs] = num_tokens
    input_buffers.seq_lens[num_reqs:] = 0

    input_buffers.dcp_local_seq_lens[:num_reqs] = num_tokens
    input_buffers.dcp_local_seq_lens[num_reqs:] = 0

    input_block_tables = [x[:num_reqs] for x in block_tables.input_block_tables]
    slot_mappings = block_tables.slot_mappings[:, :num_tokens]
    slot_mappings_by_layer = build_slot_mappings_by_layer(
        slot_mappings, kv_cache_config
    )

    attn_metadata = build_attn_metadata(
        attn_groups=attn_groups,
        num_reqs=num_reqs,
        num_tokens=num_tokens,
        query_start_loc_gpu=query_start_loc,
        query_start_loc_cpu=query_start_loc_cpu,
        max_query_len=max_query_len,
        seq_lens=input_buffers.seq_lens,
        max_seq_len=max_model_len,
        block_tables=input_block_tables,
        slot_mappings=slot_mappings,
        kv_cache_config=kv_cache_config,
        dcp_local_seq_lens=input_buffers.dcp_local_seq_lens,
    )
    return attn_metadata, slot_mappings_by_layer