cudagraph_utils.py 17.2 KB
Newer Older
Woosuk Kwon's avatar
Woosuk Kwon committed
1
2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from collections import defaultdict
4
from collections.abc import Callable
5
from dataclasses import dataclass
6
from typing import Any
Woosuk Kwon's avatar
Woosuk Kwon committed
7
8
9
10
11
12
13

import torch
import torch.nn as nn
from tqdm import tqdm

from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode
14
15
16
17
18
from vllm.distributed.parallel_state import (
    get_pp_group,
    graph_capture,
    is_global_first_rank,
)
19
from vllm.forward_context import BatchDescriptor, set_forward_context
20
from vllm.logger import init_logger
21
from vllm.model_executor.offloader.base import get_offloader
22
from vllm.platforms import current_platform
23
from vllm.sequence import IntermediateTensors
Woosuk Kwon's avatar
Woosuk Kwon committed
24
from vllm.v1.kv_cache_interface import KVCacheConfig
25
from vllm.v1.worker.gpu.attn_utils import build_slot_mappings_by_layer
Woosuk Kwon's avatar
Woosuk Kwon committed
26
from vllm.v1.worker.gpu.block_table import BlockTables
27
28
from vllm.v1.worker.gpu.cp_utils import prepare_dcp_local_seq_lens
from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers
29
from vllm.v1.worker.gpu.model_states.interface import ModelState
30
from vllm.v1.worker.utils import AttentionGroup
Woosuk Kwon's avatar
Woosuk Kwon committed
31

32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
# Module-level logger, named after this module (used for capture debug logs).
logger = init_logger(__name__)


@dataclass(frozen=True)
class BatchExecutionDescriptor:
    """Describes the shape of the batch and CG mode to run; this is used to make shape
    matches between the capture and runtime.

    Instances are frozen (and therefore hashable), which lets them serve as
    dictionary keys for captured graphs.
    """

    # CUDA graph mode the batch runs under (NONE means eager execution).
    cg_mode: CUDAGraphMode
    # Number of tokens; for captured graphs this is the capture size.
    num_tokens: int
    num_reqs: int | None  # None means no request padding is needed (PIECEWISE graphs)
    # Tokens per request for uniform batches (e.g. uniform decode).
    # None means the descriptor accepts any batch shape.
    uniform_token_count: int | None = None


def _is_compatible(
    desc: BatchExecutionDescriptor,
    num_reqs: int,
    num_tokens: int,
    uniform_token_count: int | None,
) -> bool:
    """Report whether a captured descriptor can serve a runtime batch.

    ``desc`` fits when it has room for every token, has room for every
    request (unless it imposes no request limit), and either places no
    uniformity requirement or matches the batch's uniform token count.
    """
    # The captured graph must have at least as many token slots as the batch.
    if desc.num_tokens < num_tokens:
        return False
    # desc.num_reqs=None means no request padding needed (PIECEWISE).
    if desc.num_reqs is not None and desc.num_reqs < num_reqs:
        return False
    # desc.uniform_token_count=None (PIECEWISE) can handle any uniform_token_count.
    required_count = desc.uniform_token_count
    return required_count is None or required_count == uniform_token_count


def get_uniform_token_count(
    num_reqs: int,
    num_tokens: int,
    max_query_len: int,
) -> int | None:
    """
    Return the uniform token count if batch is uniform, else None.
    A batch is uniform if all requests have the same number of tokens.
    """
    if (max_query_len == num_tokens // num_reqs) and (
        num_tokens == max_query_len * num_reqs
    ):
        return max_query_len
    return None

Woosuk Kwon's avatar
Woosuk Kwon committed
79
80

class CudaGraphManager:
    """Capture CUDA graphs per ``BatchExecutionDescriptor`` and dispatch to them.

    At construction, the manager enumerates the descriptors to capture for
    every configured capture size:
    - a uniform-decode descriptor, when the mode uses a separate decode
      routine;
    - a mixed-batch descriptor.

    ``capture`` records them. ``dispatch`` maps a runtime batch to the best
    compatible descriptor. ``run_fullgraph`` replays a captured FULL graph.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        cudagraph_mode: CUDAGraphMode,
        decode_query_len: int,
    ):
        self.vllm_config = vllm_config
        self.device = device

        self.max_num_reqs = vllm_config.scheduler_config.max_num_seqs
        self.compilation_config = vllm_config.compilation_config
        assert self.compilation_config is not None
        self.cudagraph_mode = cudagraph_mode
        self.decode_query_len = decode_query_len

        self.dp_size = vllm_config.parallel_config.data_parallel_size
        self.is_first_pp_rank = get_pp_group().is_first_rank
        self.is_last_pp_rank = get_pp_group().is_last_rank

        # Captured FULL graphs. PIECEWISE capture happens inside the forward
        # function passed to `capture` and is not stored here.
        self.graphs: dict[BatchExecutionDescriptor, torch.cuda.CUDAGraph] = {}
        # Compare against NONE explicitly instead of relying on enum truthiness
        # (plain Enum members are always truthy).
        self.pool = (
            current_platform.get_global_graph_pool()
            if cudagraph_mode != CUDAGraphMode.NONE
            else None
        )

        self._graphs_captured = False
        # _candidates[n]: priority-ordered descriptors for a batch of n tokens.
        self._candidates: list[list[BatchExecutionDescriptor]] = []
        # Descriptors to capture per mode, largest token count first.
        self._capture_descs: dict[CUDAGraphMode, list[BatchExecutionDescriptor]] = {}
        self._init_candidates()

    def _init_candidates(self) -> None:
        """Build priority-ordered candidate lists for each token count.

        Fills two structures:
        - ``self._candidates``: index is a runtime token count; value is the
          descriptors of the smallest capture size that can hold it, with
          decode-specific descriptors first.
        - ``self._capture_descs``: descriptors to capture per mode.

        Does nothing when CUDA graphs are disabled or no capture sizes are
        configured.
        """
        capture_sizes = self.compilation_config.cudagraph_capture_sizes
        if self.cudagraph_mode == CUDAGraphMode.NONE or not capture_sizes:
            return

        capture_sizes = sorted(capture_sizes)
        max_decode_tokens = self.max_num_reqs * self.decode_query_len
        decode_mode = self.cudagraph_mode.decode_mode()
        mixed_mode = self.cudagraph_mode.mixed_mode()
        separate_decode_routine = self.cudagraph_mode.separate_routine()

        descs_by_token_count = defaultdict(list)
        descs_by_mode = defaultdict(list)

        for num_tokens in capture_sizes:
            # Capture uniform-decode-specific graphs if required
            # (i.e. separate decode routine).
            if (
                separate_decode_routine
                and decode_mode != CUDAGraphMode.NONE
                and self.decode_query_len <= num_tokens <= max_decode_tokens
            ):
                desc = BatchExecutionDescriptor(
                    cg_mode=decode_mode,
                    num_tokens=num_tokens,
                    num_reqs=num_tokens // self.decode_query_len,
                    uniform_token_count=self.decode_query_len,
                )
                descs_by_mode[decode_mode].append(desc)
                descs_by_token_count[num_tokens].append(desc)

            if mixed_mode != CUDAGraphMode.NONE:
                # For PIECEWISE graphs there is no limit on requests when
                # replaying (no request padding is needed), so num_reqs stays
                # None.
                num_reqs = (
                    min(num_tokens, self.max_num_reqs)
                    if mixed_mode == CUDAGraphMode.FULL
                    else None
                )
                desc = BatchExecutionDescriptor(
                    cg_mode=mixed_mode,
                    num_tokens=num_tokens,
                    num_reqs=num_reqs,
                )
                descs_by_mode[mixed_mode].append(desc)
                descs_by_token_count[num_tokens].append(desc)

        if not descs_by_token_count:
            return

        sorted_padded = sorted(descs_by_token_count.keys())
        self._candidates = [[] for _ in range(sorted_padded[-1] + 1)]

        # Token counts in (previous size, cg_size] map to cg_size's candidates,
        # i.e. the smallest captured size that can hold them.
        current_range_start = 0
        for cg_size in sorted_padded:
            for i in range(current_range_start, cg_size + 1):
                self._candidates[i] = descs_by_token_count[cg_size]
            current_range_start = cg_size + 1

        for mode, descs in descs_by_mode.items():
            descs.sort(key=lambda d: d.num_tokens, reverse=True)
            self._capture_descs[mode] = descs

    def needs_capture(self) -> bool:
        """Return True if at least one mode has descriptors to capture."""
        return len(self._capture_descs) > 0

    @torch.inference_mode()
    def capture(
        self,
        create_forward_fn: Callable[
            [BatchExecutionDescriptor], Callable[[CUDAGraphMode], None]
        ],
        progress_bar_desc: str = "Capturing CUDA graphs",
    ) -> None:
        """Capture CUDA graphs.

        For each descriptor, run one eager warmup pass and then capture.
        PIECEWISE descriptors are processed first, then FULL; within a mode
        the largest token count comes first.

        Args:
            create_forward_fn: Factory that prepares inputs (OUTSIDE graph) and
                returns a function that runs forward with a given CUDAGraphMode.
            progress_bar_desc: Label of the progress bar, which is shown only
                on the global first rank.
        """
        with graph_capture(device=self.device):
            # Capture in order: PIECEWISE first, then FULL. PIECEWISE has larger
            # activations so FULL activations should fit in already allocated
            # buffers in the graph pool.
            for mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL]:
                if mode not in self._capture_descs:
                    continue

                descs = self._capture_descs[mode]
                if is_global_first_rank():
                    descs = tqdm(descs, desc=f"{progress_bar_desc} ({mode.name})")
                for desc in descs:
                    # Prepare inputs and get forward function
                    forward_fn = create_forward_fn(desc)

                    # Warmup
                    forward_fn(CUDAGraphMode.NONE)

                    # Capture
                    logger.debug(
                        "CG Capture: mode=%s, batch_desc=%s", desc.cg_mode.name, desc
                    )
                    if desc.cg_mode == CUDAGraphMode.PIECEWISE:
                        forward_fn(CUDAGraphMode.PIECEWISE)
                    else:
                        assert desc not in self.graphs, (
                            f"Graph already captured for {desc}"
                        )
                        graph = torch.cuda.CUDAGraph()
                        # Sync offloader's copy stream before capture.
                        # Ensure any pre-capture prefetches from offloader are complete.
                        get_offloader().sync_prev_onload()
                        with torch.cuda.graph(graph, self.pool):
                            forward_fn(CUDAGraphMode.NONE)
                            # Join offloader's copy stream after forward to avoid
                            # unjoined stream error. The last layer's start_prefetch
                            # forks copy_stream, but wait_prefetch only happens in
                            # the next forward pass.
                            get_offloader().join_after_forward()
                        self.graphs[desc] = graph
        self._graphs_captured = True

    def dispatch(
        self,
        num_reqs: int,
        num_tokens: int,
        uniform_token_count: int | None,
    ) -> BatchExecutionDescriptor:
        """Find matching cudagraph descriptor from priority-ordered candidates.

        Returns the first compatible candidate for ``num_tokens``. If graphs
        have not been captured yet, or nothing matches, returns a NONE-mode
        descriptor for the unpadded batch.
        """
        if self._graphs_captured and 0 < num_tokens < len(self._candidates):
            for desc in self._candidates[num_tokens]:
                if _is_compatible(desc, num_reqs, num_tokens, uniform_token_count):
                    return desc
        return BatchExecutionDescriptor(
            cg_mode=CUDAGraphMode.NONE, num_tokens=num_tokens, num_reqs=num_reqs
        )

    def run_fullgraph(self, desc: BatchExecutionDescriptor):
        """Replay a captured FULL cudagraph."""
        assert desc.cg_mode == CUDAGraphMode.FULL, (
            f"Expected FULL mode, got {desc.cg_mode}"
        )
        assert desc in self.graphs, f"No cudagraph for {desc}"
        # Sync offloader before replay - needed when transitioning from
        # eager/piecewise to full cudagraph (e.g., prefill → decode).
        # The previous eager iteration's start_prefetch may have queued
        # H2D copies on copy_stream that the graph's captured events
        # cannot see. Without this, replay could overwrite static buffers
        # while those copies are still in flight.
        get_offloader().sync_prev_onload()
        self.graphs[desc].replay()


class ModelCudaGraphManager(CudaGraphManager):
    """CudaGraphManager with model-specific capture and hidden state management."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        cudagraph_mode: CUDAGraphMode,
        decode_query_len: int,
    ):
        super().__init__(vllm_config, device, cudagraph_mode, decode_query_len)
        # Used for FULL CUDA graphs. PW CUDA graphs do not use these.
        # Persistent output buffers written during capture and sliced by
        # run_fullgraph after replay.
        self.hidden_states: torch.Tensor | None = None
        self.aux_hidden_states: list[torch.Tensor] = []
        self.use_aux_hidden_state_outputs = False
        self.intermediate_tensors: IntermediateTensors | None = None

    def capture(
        self,
        model: nn.Module,
        model_state: ModelState,
        input_buffers: InputBuffers,
        intermediate_tensors: IntermediateTensors | None,
        block_tables: BlockTables,
        attn_groups: list[list[AttentionGroup]],
        kv_cache_config: KVCacheConfig,
        has_lora: bool = False,
        use_aux_hidden_state_outputs: bool = False,
        progress_bar_desc: str = "Capturing CUDA graphs",
    ) -> None:
        """Capture CUDA graphs for model forward pass.

        Args:
            model: Module whose forward pass is captured.
            model_state: Supplies dummy model inputs; it is also passed to
                `prepare_inputs_to_capture`.
            input_buffers: Persistent input buffers. Each capture reads the
                first ``num_tokens`` entries of ``input_ids``/``positions``.
            intermediate_tensors: Persistent inputs for non-first PP ranks;
                required (asserted non-None) on those ranks.
            block_tables: Used to build dummy block tables and slot mappings.
            attn_groups: Forwarded to attention-metadata preparation.
            kv_cache_config: Forwarded to slot-mapping and attention-metadata
                preparation.
            has_lora: Currently unused by this method.
            use_aux_hidden_state_outputs: If True, the model output is unpacked
                as ``(hidden_states, aux_hidden_states)``, and both are copied
                into persistent buffers.
            progress_bar_desc: Label of the capture progress bar.
        """
        self.use_aux_hidden_state_outputs = use_aux_hidden_state_outputs

        def create_forward_fn(
            desc: BatchExecutionDescriptor,
        ) -> Callable[[CUDAGraphMode], None]:
            num_tokens = desc.num_tokens
            # PIECEWISE descriptors carry no request count; use a dummy one.
            num_reqs = desc.num_reqs or min(num_tokens, self.max_num_reqs)
            num_tokens_across_dp = (
                torch.full((self.dp_size,), num_tokens, dtype=torch.int32, device="cpu")
                if self.dp_size > 1
                else None
            )

            model_inputs = {
                "input_ids": input_buffers.input_ids[:num_tokens],
                "positions": input_buffers.positions[:num_tokens],
                **model_state.prepare_dummy_inputs(num_reqs, num_tokens),
            }
            if not self.is_first_pp_rank:
                # Update for non-first PP ranks.
                model_inputs["input_ids"] = None
                model_inputs["inputs_embeds"] = None
                assert intermediate_tensors is not None
                model_inputs["intermediate_tensors"] = intermediate_tensors[:num_tokens]

            attn_metadata, slot_mappings = prepare_inputs_to_capture(
                num_reqs,
                num_tokens,
                model_state,
                input_buffers,
                block_tables,
                attn_groups,
                kv_cache_config,
            )

            def forward_fn(cg_mode: CUDAGraphMode) -> None:
                is_piecewise = cg_mode == CUDAGraphMode.PIECEWISE
                batch_descriptor = (
                    BatchDescriptor(num_tokens=num_tokens) if is_piecewise else None
                )
                with set_forward_context(
                    None if is_piecewise else attn_metadata,
                    self.vllm_config,
                    num_tokens=num_tokens,
                    cudagraph_runtime_mode=cg_mode,
                    num_tokens_across_dp=num_tokens_across_dp,
                    slot_mapping=slot_mappings,
                    batch_descriptor=batch_descriptor,
                ):
                    model_output = model(**model_inputs)

                if is_piecewise:
                    # PW CUDA graph internally handles the model outputs.
                    # No need to keep track of the hidden states.
                    return None

                self._copy_outputs_to_buffers(model_output, num_tokens)

            return forward_fn

        super().capture(create_forward_fn, progress_bar_desc)

    def _copy_outputs_to_buffers(self, model_output: Any, num_tokens: int) -> None:
        """Copy one non-PIECEWISE forward's outputs into persistent buffers.

        Buffers are allocated lazily from the first output seen.
        ``CudaGraphManager.capture`` visits descriptors largest-first within
        each mode, so later (smaller) writes are expected to fit.
        """
        if not self.is_last_pp_rank:
            # Non-last PP rank: the model returns intermediate tensors.
            assert isinstance(model_output, IntermediateTensors)
            if self.intermediate_tensors is None:
                self.intermediate_tensors = IntermediateTensors.empty_like(
                    model_output
                )
            for key, value in model_output.tensors.items():
                self.intermediate_tensors[key][:num_tokens] = value
            return

        # Last PP rank (common case).
        if self.use_aux_hidden_state_outputs:
            hidden_states, aux_hidden_states = model_output
        else:
            hidden_states = model_output
            aux_hidden_states = []
        if self.hidden_states is None:
            self.hidden_states = torch.empty_like(hidden_states)
        self.hidden_states[:num_tokens] = hidden_states

        if self.use_aux_hidden_state_outputs and not self.aux_hidden_states:
            self.aux_hidden_states = [torch.empty_like(x) for x in aux_hidden_states]
        for i, aux in enumerate(aux_hidden_states):
            self.aux_hidden_states[i][:num_tokens] = aux

    def run_fullgraph(
        self, desc: BatchExecutionDescriptor
    ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]] | IntermediateTensors:
        """Replay a captured FULL cudagraph and return hidden states.

        Returns:
            - On non-last PP ranks: the intermediate tensors, sliced to
              ``desc.num_tokens``.
            - Otherwise: the hidden states, plus the list of aux hidden states
              when aux outputs are enabled, each sliced to ``desc.num_tokens``.
        """
        super().run_fullgraph(desc)
        if not self.is_last_pp_rank:
            assert self.intermediate_tensors is not None
            return self.intermediate_tensors[: desc.num_tokens]

        assert self.hidden_states is not None
        hidden_states = self.hidden_states[: desc.num_tokens]
        if not self.use_aux_hidden_state_outputs:
            return hidden_states
        return hidden_states, [x[: desc.num_tokens] for x in self.aux_hidden_states]
396
397
398
399
400


def prepare_inputs_to_capture(
    num_reqs: int,
    num_tokens: int,
    model_state: ModelState,
    input_buffers: InputBuffers,
    block_tables: BlockTables,
    attn_groups: list[list[AttentionGroup]],
    kv_cache_config: KVCacheConfig,
) -> tuple[dict[str, Any], dict[str, torch.Tensor]]:
    """Build dummy attention metadata and slot mappings for graph capture.

    Args:
        num_reqs: Number of dummy requests in the batch.
        num_tokens: Number of dummy tokens in the batch.
        model_state: Builds the attention metadata (called with
            ``for_capture=True``).
        input_buffers: Persistent buffers backing the dummy input batch.
        block_tables: Source of dummy block tables and slot mappings. Also
            supplies context-parallel parameters (``cp_size``, ``cp_rank``,
            ``cp_interleave``).
        attn_groups: Forwarded to attention-metadata preparation.
        kv_cache_config: Used to split slot mappings per layer and forwarded
            to attention-metadata preparation.

    Returns:
        ``(attn_metadata, slot_mappings_by_layer)``.
    """
    input_batch = InputBatch.make_dummy(num_reqs, num_tokens, input_buffers)
    input_block_tables = block_tables.get_dummy_block_tables(num_reqs)
    slot_mappings = block_tables.get_dummy_slot_mappings(num_tokens)
    slot_mappings_by_layer = build_slot_mappings_by_layer(
        slot_mappings, kv_cache_config
    )

    # HACK(woosuk): Special handling for DCP.
    # When cp_size > 1, fill the persistent dcp_local_seq_lens buffer from the
    # dummy batch's seq_lens. Then expose its first num_reqs entries on the
    # batch.
    if block_tables.cp_size > 1:
        prepare_dcp_local_seq_lens(
            input_buffers.dcp_local_seq_lens,
            input_batch.seq_lens,
            num_reqs,
            block_tables.cp_size,
            block_tables.cp_rank,
            block_tables.cp_interleave,
        )
        input_batch.dcp_local_seq_lens = input_buffers.dcp_local_seq_lens[:num_reqs]

    attn_metadata = model_state.prepare_attn(
        input_batch,
        CUDAGraphMode.NONE,
        input_block_tables,
        slot_mappings,
        attn_groups,
        kv_cache_config,
        for_capture=True,
    )
    return attn_metadata, slot_mappings_by_layer