cudagraph_utils.py 16.3 KB
Newer Older
Woosuk Kwon's avatar
Woosuk Kwon committed
1
2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from collections.abc import Callable
4
from typing import Any
Woosuk Kwon's avatar
Woosuk Kwon committed
5
6
7
8
9
10
11
12
13

import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm

from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode
from vllm.distributed.parallel_state import graph_capture, is_global_first_rank
14
15
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.utils.math_utils import cdiv
Woosuk Kwon's avatar
Woosuk Kwon committed
16
from vllm.v1.kv_cache_interface import KVCacheConfig
17
18
19
20
from vllm.v1.worker.gpu.attn_utils import (
    build_attn_metadata,
    build_slot_mappings_by_layer,
)
Woosuk Kwon's avatar
Woosuk Kwon committed
21
from vllm.v1.worker.gpu.block_table import BlockTables
22
from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp
Woosuk Kwon's avatar
Woosuk Kwon committed
23
from vllm.v1.worker.gpu.input_batch import InputBuffers
24
from vllm.v1.worker.utils import AttentionGroup
Woosuk Kwon's avatar
Woosuk Kwon committed
25
26
27


class CudaGraphManager:
28
29
30
31
32
33
34
    def __init__(
        self,
        vllm_config: VllmConfig,
        uses_mrope: bool,
        use_aux_hidden_state_outputs: bool,
        device: torch.device,
    ):
        """Read capture-related settings from ``vllm_config`` and set up state.

        Args:
            vllm_config: Top-level config. The scheduler, model, parallel,
                speculative and compilation sub-configs are read from it.
            uses_mrope: If true, ``capture_graph`` feeds the model the
                caller-supplied ``mrope_positions`` instead of
                ``input_buffers.positions``.
            use_aux_hidden_state_outputs: If true, the model output is
                unpacked as ``(hidden_states, aux_hidden_states)``.
            device: Device handed to ``graph_capture`` while capturing.
        """
        self.vllm_config = vllm_config
        self.scheduler_config = vllm_config.scheduler_config
        self.uses_mrope = uses_mrope
        self.use_aux_hidden_state_outputs = use_aux_hidden_state_outputs
        self.device = device

        self.max_model_len = vllm_config.model_config.max_model_len
        self.max_num_reqs = self.scheduler_config.max_num_seqs
        self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
        self.dp_size = vllm_config.parallel_config.data_parallel_size

        # Query length of every request in a "uniform decode" batch: one
        # token, plus the speculative tokens when speculation is configured.
        self.uniform_decode_query_len = 1
        spec_config = vllm_config.speculative_config
        if spec_config is not None:
            self.uniform_decode_query_len += spec_config.num_speculative_tokens

        self.compilation_config = vllm_config.compilation_config
        assert self.compilation_config is not None
        self.cudagraph_mode = self.compilation_config.cudagraph_mode

        # A dedicated set of decode graphs is only built when decode batches
        # run in FULL mode through a separate routine.
        use_uniform_decode_cudagraph = (
            self.cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
            and self.cudagraph_mode.separate_routine()
        )
        # Both results map a token count to the capture size that serves it.
        self.cudagraph_sizes, self.uniform_decode_cudagraph_sizes = get_cudagraph_sizes(
            self.compilation_config.cudagraph_capture_sizes,
            self.max_num_reqs,
            self.max_num_tokens,
            self.cudagraph_mode,
            self.uniform_decode_query_len,
            use_uniform_decode_cudagraph,
        )

        # Full graphs, keyed by the number of tokens they were captured with.
        self.graphs: dict[int, torch.cuda.CUDAGraph] = {}
        # Pool handle passed to torch.cuda.graph() for every full capture.
        self.pool = None
        if self.cudagraph_mode != CUDAGraphMode.NONE:
            self.pool = torch.cuda.graph_pool_handle()
        # Output buffers; allocated by the first capture_graph() call.
        self.hidden_states: torch.Tensor | None = None
        self.aux_hidden_states: list[torch.Tensor] = []
Woosuk Kwon's avatar
Woosuk Kwon committed
74
75

    def needs_capture(self) -> bool:
        """Return True when at least one cudagraph size is configured."""
        return bool(self.cudagraph_sizes)
Woosuk Kwon's avatar
Woosuk Kwon committed
77
78

    def get_cudagraph_size(
        self, num_tokens: int, uniform_decode: bool = False
    ) -> int | None:
        """Return the capture size that serves ``num_tokens``, if any.

        Args:
            num_tokens: Number of tokens in the batch.
            uniform_decode: Whether the batch is a uniform decode batch.

        Returns:
            The mapped capture size, or None when ``num_tokens`` has no entry
            in the table that was consulted.
        """
        # The decode table is only consulted when it is non-empty; otherwise
        # uniform decode batches share the general table.
        if uniform_decode and self.uniform_decode_cudagraph_sizes:
            return self.uniform_decode_cudagraph_sizes.get(num_tokens)
        return self.cudagraph_sizes.get(num_tokens)
Woosuk Kwon's avatar
Woosuk Kwon committed
84
85
86

    def capture_graph(
        self,
        num_tokens: int,
        capture_cg_mode: CUDAGraphMode,
        model: nn.Module,
        input_buffers: InputBuffers,
        mrope_positions: torch.Tensor | None,
        inputs_embeds: torch.Tensor | None,
        block_tables: BlockTables,
        attn_groups: list[list[AttentionGroup]],
        kv_cache_config: KVCacheConfig,
        has_lora: bool = False,
        uniform_decode: bool = False,
    ) -> None:
        """Warm up the model and capture one graph for ``num_tokens`` tokens.

        Args:
            num_tokens: Batch size, in tokens, to capture.
            capture_cg_mode: Must be ``CUDAGraphMode.PIECEWISE`` or
                ``CUDAGraphMode.FULL``; selects the capture routine.
            model: Model invoked for the warm-up and the capture.
            input_buffers: Persistent buffers; ``input_ids`` and ``positions``
                are sliced to ``num_tokens`` and the attention-related buffers
                are filled by ``prepare_inputs_to_capture``.
            mrope_positions: Required when ``self.uses_mrope`` is set; sliced
                as ``[:, :num_tokens]`` and used as the positions.
            inputs_embeds: Optional embeddings, sliced to ``num_tokens``.
            block_tables: Forwarded to ``prepare_inputs_to_capture``.
            attn_groups: Forwarded to ``prepare_inputs_to_capture``.
            kv_cache_config: Forwarded to ``prepare_inputs_to_capture``.
            has_lora: Forwarded to the selected capture routine.
            uniform_decode: If true, build the dummy batch with
                ``self.uniform_decode_query_len`` tokens per request.
        """
        # Select and check the capture function.
        assert capture_cg_mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], (
            f"Invalid capture_cudagraph_mode for capture: {capture_cg_mode}"
        )
        if capture_cg_mode == CUDAGraphMode.PIECEWISE:
            capture_fn = self._capture_piecewise_graph
        else:
            capture_fn = self._capture_full_graph

        # Prepare inputs.
        if uniform_decode:
            num_reqs = min(
                cdiv(num_tokens, self.uniform_decode_query_len),
                self.max_num_reqs,
            )
        else:
            num_reqs = min(num_tokens, self.max_num_reqs)
        input_ids = input_buffers.input_ids[:num_tokens]
        positions = input_buffers.positions[:num_tokens]
        if self.uses_mrope:
            assert mrope_positions is not None
            positions = mrope_positions[:, :num_tokens]
        if inputs_embeds is not None:
            inputs_embeds = inputs_embeds[:num_tokens]
        attn_metadata, slot_mappings = prepare_inputs_to_capture(
            num_reqs,
            num_tokens,
            input_buffers,
            block_tables,
            attn_groups,
            self.max_model_len,
            kv_cache_config,
            # 0 tells the helper to split tokens evenly across requests.
            uniform_decode_query_len=(
                self.uniform_decode_query_len if uniform_decode else 0
            ),
        )
        num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens)

        # Warm up. This runs with cudagraphs disabled, before any capture.
        with set_forward_context(
            attn_metadata,
            self.vllm_config,
            num_tokens=num_tokens,
            cudagraph_runtime_mode=CUDAGraphMode.NONE,
            num_tokens_across_dp=num_tokens_across_dp,
            slot_mapping=slot_mappings,
        ):
            model_output = model(
                input_ids=input_ids,
                positions=positions,
                inputs_embeds=inputs_embeds,
            )
            if self.use_aux_hidden_state_outputs:
                hidden_states, aux_hidden_states = model_output
            else:
                hidden_states = model_output
                aux_hidden_states = None

        # Allocate output buffers if not already done. They are shaped like
        # the first warm-up output and later written as [:num_tokens], so the
        # first captured size must be the largest one (capture_graphs()
        # iterates sizes in descending order).
        if self.hidden_states is None:
            self.hidden_states = torch.empty_like(hidden_states)
        if self.use_aux_hidden_state_outputs and not self.aux_hidden_states:
            self.aux_hidden_states = [torch.empty_like(x) for x in aux_hidden_states]

        capture_fn(
            num_tokens=num_tokens,
            num_reqs=num_reqs,
            model=model,
            input_ids=input_ids,
            positions=positions,
            inputs_embeds=inputs_embeds,
            num_tokens_across_dp=num_tokens_across_dp,
            attn_metadata=attn_metadata,
            slot_mappings=slot_mappings,
            has_lora=has_lora,
        )

    def _capture_full_graph(
        self,
        num_tokens: int,
        num_reqs: int,
        model: nn.Module,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        inputs_embeds: torch.Tensor | None,
        num_tokens_across_dp: torch.Tensor,
        attn_metadata: dict[str, Any] | None,
        slot_mappings: dict[str, torch.Tensor] | None,
        has_lora: bool = False,
    ) -> None:
        """Record one forward pass into a ``torch.cuda.CUDAGraph``.

        The graph is stored in ``self.graphs[num_tokens]``. The copy of the
        model output into ``self.hidden_states`` (and the aux buffers) is
        issued inside the capture context, so it is part of the graph.

        ``num_reqs`` and ``has_lora`` are accepted so the signature matches
        ``_capture_piecewise_graph``; this routine does not read them.

        Raises:
            AssertionError: If ``attn_metadata`` is None, a graph for
                ``num_tokens`` already exists, or the output buffer has not
                been allocated.
        """
        assert attn_metadata is not None
        # Capture the graph.
        assert num_tokens not in self.graphs
        graph = torch.cuda.CUDAGraph()
        with (
            set_forward_context(
                attn_metadata=attn_metadata,
                vllm_config=self.vllm_config,
                num_tokens=num_tokens,
                cudagraph_runtime_mode=CUDAGraphMode.NONE,
                num_tokens_across_dp=num_tokens_across_dp,
                slot_mapping=slot_mappings,
            ),
            torch.cuda.graph(graph, self.pool),
        ):
            model_output = model(
                input_ids=input_ids,
                positions=positions,
                inputs_embeds=inputs_embeds,
            )
            if self.use_aux_hidden_state_outputs:
                hidden_states, aux_hidden_states = model_output
            else:
                hidden_states = model_output
                aux_hidden_states = None

            # Copy outputs to the output buffers.
            assert self.hidden_states is not None
            self.hidden_states[:num_tokens] = hidden_states
            if self.use_aux_hidden_state_outputs:
                for i, aux_hidden in enumerate(aux_hidden_states):
                    self.aux_hidden_states[i][:num_tokens] = aux_hidden
        self.graphs[num_tokens] = graph
Woosuk Kwon's avatar
Woosuk Kwon committed
221

222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
    def _capture_piecewise_graph(
        self,
        num_tokens: int,
        num_reqs: int,
        model: nn.Module,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        inputs_embeds: torch.Tensor | None,
        num_tokens_across_dp: torch.Tensor,
        attn_metadata: dict[str, Any] | None,
        slot_mappings: dict[str, torch.Tensor] | None,
        has_lora: bool = False,
    ) -> None:
        """Run the model once in PIECEWISE mode for ``num_tokens`` tokens.

        Nothing is stored on ``self`` and the model output is discarded.
        ``num_reqs`` and ``attn_metadata`` are accepted so the signature
        matches ``_capture_full_graph``; this routine does not read them.
        """
        # Create the batch descriptor used as the piecewise cudagraph
        # dispatch key.
        batch_descriptor = BatchDescriptor(num_tokens=num_tokens, has_lora=has_lora)

        # Capture run - CUDAGraphWrapper inside torch.compile will auto capture.
        with set_forward_context(
            attn_metadata=None,  # piecewise does not need attn_metadata
            vllm_config=self.vllm_config,
            num_tokens=num_tokens,
            cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
            num_tokens_across_dp=num_tokens_across_dp,
            batch_descriptor=batch_descriptor,
            slot_mapping=slot_mappings,
        ):
            model(
                input_ids=input_ids,
                positions=positions,
                inputs_embeds=inputs_embeds,
            )

Woosuk Kwon's avatar
Woosuk Kwon committed
254
255
256
257
258
    @torch.inference_mode()
    def capture(
        self,
        model: nn.Module,
        input_buffers: InputBuffers,
        mrope_positions: torch.Tensor | None,
        inputs_embeds: torch.Tensor | None,
        block_tables: BlockTables,
        attn_groups: list[list[AttentionGroup]],
        kv_cache_config: KVCacheConfig,
        has_lora: bool = False,
    ) -> None:
        """Capture every configured graph, in up to two phases.

        Phase 1 covers mixed prefill-decode batches using
        ``self.cudagraph_mode.mixed_mode()`` and is skipped when that mode is
        NONE. Phase 2 captures FULL graphs for uniform decode batches and is
        skipped when ``self.uniform_decode_cudagraph_sizes`` is empty.

        All arguments are forwarded unchanged to ``capture_graph`` for each
        captured size.
        """
        # Arguments shared by both phases.
        common_kwargs = dict(
            device=self.device,
            capture_fn=self.capture_graph,
            model=model,
            input_buffers=input_buffers,
            mrope_positions=mrope_positions,
            inputs_embeds=inputs_embeds,
            block_tables=block_tables,
            attn_groups=attn_groups,
            kv_cache_config=kv_cache_config,
            has_lora=has_lora,
        )

        # Phase 1: Capture for mixed prefill-decode batches if needed.
        mixed_mode = self.cudagraph_mode.mixed_mode()
        if mixed_mode != CUDAGraphMode.NONE:
            capture_graphs(
                cudagraph_sizes=self.cudagraph_sizes,
                capture_cudagraph_mode=mixed_mode,
                desc=f"Capturing CUDA graphs (mixed, {mixed_mode.name})",
                uniform_decode=False,
                **common_kwargs,
            )

        # Phase 2: Capture FULL graphs for uniform decode batches if needed.
        # This is only needed if we use a separate routine for decode batches
        # and the decode_mode is FULL.
        if self.uniform_decode_cudagraph_sizes:
            capture_graphs(
                cudagraph_sizes=self.uniform_decode_cudagraph_sizes,
                capture_cudagraph_mode=CUDAGraphMode.FULL,
                desc="Capturing CUDA graphs (decode, FULL)",
                uniform_decode=True,
                **common_kwargs,
            )

    def get_cudagraph_runtime_mode(
        self, num_reqs: int, num_tokens: int, max_query_len: int
    ) -> tuple[CUDAGraphMode, int | None]:
        """Choose the cudagraph mode and capture size for a batch.

        A batch counts as a uniform decode when ``max_query_len`` equals
        ``self.uniform_decode_query_len`` and
        ``num_tokens == max_query_len * num_reqs``.

        Returns:
            ``(mode, size)``. ``size`` is None whenever ``mode`` is
            ``CUDAGraphMode.NONE`` because no size is mapped or the FULL
            graph has not been captured.
        """
        is_uniform_decode = (max_query_len == self.uniform_decode_query_len) and (
            num_tokens == max_query_len * num_reqs
        )

        cudagraph_size = self.get_cudagraph_size(num_tokens, is_uniform_decode)
        if cudagraph_size is None:
            cudagraph_mode = CUDAGraphMode.NONE
        elif is_uniform_decode:
            cudagraph_mode = self.cudagraph_mode.decode_mode()
        else:
            cudagraph_mode = self.cudagraph_mode.mixed_mode()

        if (
            cudagraph_mode == CUDAGraphMode.FULL
            and cudagraph_size is not None
            and cudagraph_size not in self.graphs
        ):
            # If graph wasn't captured yet, fall back to eager.
            # This might happen when the dummy run is called before capture.
            cudagraph_mode = CUDAGraphMode.NONE
            cudagraph_size = None
        return cudagraph_mode, cudagraph_size

328
329
330
    def run_fullgraph(
        self, num_tokens: int
    ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
        """Replay the full graph captured for ``num_tokens`` tokens.

        Returns:
            ``self.hidden_states[:num_tokens]`` alone, or that slice together
            with the matching slices of every aux buffer when
            ``self.use_aux_hidden_state_outputs`` is set. The results are
            slices of the manager's output buffers, not copies.

        Raises:
            AssertionError: If no graph exists for ``num_tokens`` or the
                output buffer has not been allocated.
        """
        assert num_tokens in self.graphs, f"No cudagraph for {num_tokens} tokens"
        self.graphs[num_tokens].replay()
        assert self.hidden_states is not None
        hidden_states = self.hidden_states[:num_tokens]
        if not self.use_aux_hidden_state_outputs:
            return hidden_states
        return hidden_states, [x[:num_tokens] for x in self.aux_hidden_states]
338
339
340
341
342
343
344


def get_cudagraph_sizes(
    capture_sizes: list[int] | None,
    max_num_reqs: int,
    max_num_tokens: int,
    cudagraph_mode: CUDAGraphMode,
    uniform_decode_query_len: int = 1,
    uniform_decode_cudagraph: bool = False,
) -> tuple[dict[int, int], dict[int, int]]:
    """Build the token-count -> capture-size lookup tables.

    Args:
        capture_sizes: Batch sizes (in tokens) to capture; may be None/empty.
        max_num_reqs: Maximum number of requests in a batch.
        max_num_tokens: Accepted for interface compatibility; the value is
            not consulted by this function.
        cudagraph_mode: When ``CUDAGraphMode.NONE`` both tables are empty.
        uniform_decode_query_len: Tokens per request in a uniform decode batch.
        uniform_decode_cudagraph: Whether to also build the decode table.

    Returns:
        ``(cudagraph_sizes, uniform_decode_cudagraph_sizes)``. The first maps
        every ``n`` in ``1..max(capture_sizes)`` to the smallest capture size
        ``>= n``. The second holds the entries of the first whose capture size
        lies in ``[uniform_decode_query_len,
        max_num_reqs * uniform_decode_query_len]``; it is empty unless
        ``uniform_decode_cudagraph`` is set.
    """
    # Support both FULL and PIECEWISE cudagraph modes.
    if cudagraph_mode == CUDAGraphMode.NONE or not capture_sizes:
        return {}, {}

    # Walk the sizes in ascending order and assign each one to the token
    # counts not yet covered by a smaller size. Keys are inserted in
    # ascending order.
    cudagraph_sizes: dict[int, int] = {}
    next_unmapped = 1
    for size in sorted(capture_sizes):
        for num_tokens in range(next_unmapped, size + 1):
            cudagraph_sizes[num_tokens] = size
        next_unmapped = max(next_unmapped, size + 1)

    uniform_decode_cudagraph_sizes: dict[int, int] = {}
    if uniform_decode_cudagraph:
        # Largest token count a uniform decode batch can contain.
        max_decode_tokens = max_num_reqs * uniform_decode_query_len
        uniform_decode_cudagraph_sizes = {
            num_tokens: size
            for num_tokens, size in cudagraph_sizes.items()
            if uniform_decode_query_len <= size <= max_decode_tokens
        }
    return cudagraph_sizes, uniform_decode_cudagraph_sizes
374
375
376
377
378
379


def capture_graphs(
    cudagraph_sizes: dict[int, int],
    device: torch.device,
    capture_fn: Callable,
    capture_cudagraph_mode: CUDAGraphMode,
    desc: str = "Capturing CUDA graphs",
    **capture_kwargs,
) -> None:
    """Call ``capture_fn`` once per distinct capture size, largest first.

    Args:
        cudagraph_sizes: Token-count -> capture-size table; only its distinct
            values are captured.
        device: Device passed to ``graph_capture``.
        capture_fn: Called as
            ``capture_fn(size, capture_cudagraph_mode, **capture_kwargs)``.
        capture_cudagraph_mode: Mode forwarded to ``capture_fn``.
        desc: Progress-bar label, shown on the global first rank only.
        **capture_kwargs: Extra keyword arguments forwarded to ``capture_fn``.
    """
    # Capture larger graphs first.
    sizes_to_capture = sorted(set(cudagraph_sizes.values()), reverse=True)
    if is_global_first_rank():
        sizes_to_capture = tqdm(sizes_to_capture, desc=desc)

    with graph_capture(device=device):
        for size in sizes_to_capture:
            capture_fn(size, capture_cudagraph_mode, **capture_kwargs)
392
393
394
395
396
397
398


def prepare_inputs_to_capture(
    num_reqs: int,
    num_tokens: int,
    input_buffers: InputBuffers,
    block_tables: BlockTables,
    attn_groups: list[list[AttentionGroup]],
    max_model_len: int,
    kv_cache_config: KVCacheConfig,
    uniform_decode_query_len: int = 0,
) -> tuple[dict[str, Any], dict[str, torch.Tensor]]:
    """Fill the persistent input buffers with a dummy batch for capture.

    Overwrites ``query_start_loc``, ``seq_lens`` and ``dcp_local_seq_lens``
    in ``input_buffers`` in place, then builds attention metadata from them.

    Args:
        num_reqs: Number of requests in the dummy batch (must be non-zero
            when ``uniform_decode_query_len`` is 0).
        num_tokens: Total number of tokens in the dummy batch.
        input_buffers: Buffers that are written in place.
        block_tables: Source of the per-group block tables and slot mappings.
        attn_groups: Forwarded to ``build_attn_metadata``.
        max_model_len: Reported to ``build_attn_metadata`` as ``max_seq_len``.
        kv_cache_config: Forwarded to both builder helpers.
        uniform_decode_query_len: If positive, tokens per request; otherwise
            ``num_tokens // num_reqs`` is used.

    Returns:
        ``(attn_metadata, slot_mappings_by_layer)``.
    """
    if uniform_decode_query_len > 0:
        num_tokens_per_req = uniform_decode_query_len
    else:
        num_tokens_per_req = num_tokens // num_reqs

    # Evenly spaced request boundaries; the final boundary is pinned to
    # num_tokens so the last request absorbs any remainder (or shortfall).
    # NOTE(review): in the non-uniform case the last request can therefore be
    # longer than the max_query_len reported below - confirm this is intended.
    query_start_loc_np = np.arange(num_reqs + 1, dtype=np.int32) * num_tokens_per_req
    query_start_loc_np[-1] = num_tokens
    query_start_loc_cpu = torch.from_numpy(query_start_loc_np)
    input_buffers.query_start_loc[: num_reqs + 1] = query_start_loc_cpu
    # Entries past the last request all point at the end of the batch.
    input_buffers.query_start_loc[num_reqs + 1 :] = num_tokens
    query_start_loc = input_buffers.query_start_loc[: num_reqs + 1]

    # HACK(woosuk): For faster warmup, we set seq_lens (GPU) to num_tokens
    # rather than max_model_len.
    input_buffers.seq_lens[:num_reqs] = num_tokens
    input_buffers.seq_lens[num_reqs:] = 0

    input_buffers.dcp_local_seq_lens[:num_reqs] = num_tokens
    input_buffers.dcp_local_seq_lens[num_reqs:] = 0

    input_block_tables = [x[:num_reqs] for x in block_tables.input_block_tables]
    slot_mappings = block_tables.slot_mappings[:, :num_tokens]
    slot_mappings_by_layer = build_slot_mappings_by_layer(
        slot_mappings, kv_cache_config
    )

    attn_metadata = build_attn_metadata(
        attn_groups=attn_groups,
        num_reqs=num_reqs,
        num_tokens=num_tokens,
        query_start_loc_gpu=query_start_loc,
        query_start_loc_cpu=query_start_loc_cpu,
        max_query_len=num_tokens_per_req,
        seq_lens=input_buffers.seq_lens,
        max_seq_len=max_model_len,
        block_tables=input_block_tables,
        slot_mappings=slot_mappings,
        kv_cache_config=kv_cache_config,
        dcp_local_seq_lens=input_buffers.dcp_local_seq_lens,
    )
    return attn_metadata, slot_mappings_by_layer