# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
4
from typing import Any
Woosuk Kwon's avatar
Woosuk Kwon committed
5
6
7
8
9
10
11
12
13

import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm

from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode
from vllm.distributed.parallel_state import graph_capture, is_global_first_rank
14
15
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.utils.math_utils import cdiv
16
from vllm.v1.attention.backend import AttentionMetadataBuilder
Woosuk Kwon's avatar
Woosuk Kwon committed
17
from vllm.v1.kv_cache_interface import KVCacheConfig
18
19
20
21
from vllm.v1.worker.gpu.attn_utils import (
    build_attn_metadata,
    build_slot_mappings_by_layer,
)
Woosuk Kwon's avatar
Woosuk Kwon committed
22
from vllm.v1.worker.gpu.block_table import BlockTables
23
from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp
Woosuk Kwon's avatar
Woosuk Kwon committed
24
25
26
27
from vllm.v1.worker.gpu.input_batch import InputBuffers


class CudaGraphManager:
    """Captures CUDA graphs for a set of batch sizes and replays the full ones.

    Two lookup tables are built at construction time by `get_cudagraph_sizes`:
    `cudagraph_sizes` for general batches and `uniform_decode_cudagraph_sizes`
    for uniform-decode batches. Each maps a token count to the captured size
    that should serve it. Full graphs captured by this class are stored in
    `graphs`, keyed by the token count they were captured with.
    """

    def __init__(self, vllm_config: VllmConfig, uses_mrope: bool, device: torch.device):
        """Read the relevant config fields and build the size lookup tables.

        Args:
            vllm_config: Source of the scheduler, model, parallel, speculative
                and compilation settings read below.
            uses_mrope: When true, `capture_graph` takes positions from the
                caller-supplied `mrope_positions` tensor.
            device: Device handed to `graph_capture` while capturing.
        """
        self.vllm_config = vllm_config
        scheduler_config = vllm_config.scheduler_config
        self.scheduler_config = scheduler_config
        self.uses_mrope = uses_mrope
        self.device = device

        self.max_model_len = vllm_config.model_config.max_model_len
        self.max_num_reqs = scheduler_config.max_num_seqs
        self.max_num_tokens = scheduler_config.max_num_batched_tokens
        self.dp_size = vllm_config.parallel_config.data_parallel_size

        # One token per request, plus the speculative tokens when speculative
        # decoding is configured.
        spec_config = vllm_config.speculative_config
        num_spec_tokens = (
            0 if spec_config is None else spec_config.num_speculative_tokens
        )
        self.uniform_decode_query_len = 1 + num_spec_tokens

        compilation_config = vllm_config.compilation_config
        self.compilation_config = compilation_config
        assert compilation_config is not None
        mode = compilation_config.cudagraph_mode
        self.cudagraph_mode = mode

        # A dedicated uniform-decode table is only requested when decode
        # batches use FULL graphs through a separate routine.
        decode_is_full = mode.decode_mode() == CUDAGraphMode.FULL
        use_uniform_decode_cudagraph = decode_is_full and mode.separate_routine()
        size_tables = get_cudagraph_sizes(
            compilation_config.cudagraph_capture_sizes,
            self.max_num_reqs,
            self.max_num_tokens,
            mode,
            self.uniform_decode_query_len,
            use_uniform_decode_cudagraph,
        )
        self.cudagraph_sizes, self.uniform_decode_cudagraph_sizes = size_tables

        self.graphs: dict[int, torch.cuda.CUDAGraph] = {}
        # The pool handle is only created when some cudagraph mode is enabled.
        self.pool = (
            torch.cuda.graph_pool_handle() if mode != CUDAGraphMode.NONE else None
        )
        # Allocated on the first warm-up run in `capture_graph`.
        self.hidden_states: torch.Tensor | None = None

    def needs_capture(self) -> bool:
        """Return True when the general size table has at least one entry."""
        return bool(self.cudagraph_sizes)

    def get_cudagraph_size(
        self, num_tokens: int, uniform_decode: bool = False
    ) -> int | None:
        """Look up the captured size that serves `num_tokens`.

        The uniform-decode table is consulted only when `uniform_decode` is
        set and that table is non-empty; otherwise the general table is used.

        Returns:
            The captured size, or None when the chosen table has no entry.
        """
        use_decode_table = uniform_decode and self.uniform_decode_cudagraph_sizes
        table = (
            self.uniform_decode_cudagraph_sizes
            if use_decode_table
            else self.cudagraph_sizes
        )
        return table.get(num_tokens)

    def capture_graph(
        self,
        num_tokens: int,
        capture_cg_mode: CUDAGraphMode,
        model: nn.Module,
        input_buffers: InputBuffers,
        mrope_positions: torch.Tensor | None,
        inputs_embeds: torch.Tensor | None,
        block_tables: BlockTables,
        attn_metadata_builders: list[AttentionMetadataBuilder],
        kv_cache_config: KVCacheConfig,
        has_lora: bool = False,
        uniform_decode: bool = False,
    ) -> None:
        """Warm up the model for `num_tokens` tokens, then capture one graph.

        Args:
            num_tokens: Token count to capture for.
            capture_cg_mode: PIECEWISE or FULL; selects the capture routine.
            model: Module that is run for warm-up and capture.
            input_buffers: Buffers that are sliced for the model inputs and
                filled by `prepare_inputs_to_capture`.
            mrope_positions: Required when `uses_mrope` is set.
            inputs_embeds: Optional embeddings, sliced to `num_tokens`.
            block_tables: Forwarded to `prepare_inputs_to_capture`.
            attn_metadata_builders: Forwarded to `prepare_inputs_to_capture`.
            kv_cache_config: Forwarded to `prepare_inputs_to_capture`.
            has_lora: Forwarded to the capture routine.
            uniform_decode: When set, requests are laid out with
                `uniform_decode_query_len` tokens each.

        Raises:
            AssertionError: If `capture_cg_mode` is neither PIECEWISE nor FULL,
                or if `uses_mrope` is set and `mrope_positions` is None.
        """
        # Pick the capture routine after validating the requested mode.
        assert capture_cg_mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], (
            f"Invalid capture_cudagraph_mode for capture: {capture_cg_mode}"
        )
        capture_fn = (
            self._capture_piecewise_graph
            if capture_cg_mode == CUDAGraphMode.PIECEWISE
            else self._capture_full_graph
        )

        # Number of dummy requests, never more than max_num_reqs.
        if uniform_decode:
            wanted_reqs = cdiv(num_tokens, self.uniform_decode_query_len)
        else:
            wanted_reqs = num_tokens
        num_reqs = min(wanted_reqs, self.max_num_reqs)

        # Slice the model inputs down to the capture size.
        input_ids = input_buffers.input_ids[:num_tokens]
        positions = input_buffers.positions[:num_tokens]
        if self.uses_mrope:
            assert mrope_positions is not None
            positions = mrope_positions[:, :num_tokens]
        if inputs_embeds is not None:
            inputs_embeds = inputs_embeds[:num_tokens]

        decode_query_len = self.uniform_decode_query_len if uniform_decode else 0
        attn_metadata, slot_mappings = prepare_inputs_to_capture(
            num_reqs,
            num_tokens,
            input_buffers,
            block_tables,
            attn_metadata_builders,
            self.max_model_len,
            kv_cache_config,
            uniform_decode_query_len=decode_query_len,
        )
        num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens)

        # Warm up.
        warmup_ctx = set_forward_context(
            attn_metadata,
            self.vllm_config,
            num_tokens=num_tokens,
            cudagraph_runtime_mode=CUDAGraphMode.NONE,
            num_tokens_across_dp=num_tokens_across_dp,
            slot_mapping=slot_mappings,
        )
        with warmup_ctx:
            hidden_states = model(
                input_ids=input_ids,
                positions=positions,
                inputs_embeds=inputs_embeds,
            )
            # The shared output buffer is shaped like the first warm-up output.
            # NOTE(review): capture_graphs visits sizes in descending order, so
            # presumably the first output is the largest one - confirm.
            if self.hidden_states is None:
                self.hidden_states = torch.empty_like(hidden_states)

        capture_fn(
            num_tokens=num_tokens,
            num_reqs=num_reqs,
            model=model,
            input_ids=input_ids,
            positions=positions,
            inputs_embeds=inputs_embeds,
            num_tokens_across_dp=num_tokens_across_dp,
            attn_metadata=attn_metadata,
            slot_mappings=slot_mappings,
            has_lora=has_lora,
        )

    def _capture_full_graph(
        self,
        num_tokens: int,
        num_reqs: int,
        model: nn.Module,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        inputs_embeds: torch.Tensor | None,
        num_tokens_across_dp: torch.Tensor,
        attn_metadata: dict[str, Any] | None,
        slot_mappings: dict[str, torch.Tensor] | None,
        has_lora: bool = False,
    ) -> None:
        """Record one forward pass into a new CUDA graph and store it.

        The model output is copied into `self.hidden_states[:num_tokens]`
        inside the captured region. `num_reqs` and `has_lora` are not used by
        this routine; the signature mirrors `_capture_piecewise_graph`.

        Raises:
            AssertionError: If `attn_metadata` is None, a graph for
                `num_tokens` already exists, or `self.hidden_states` has not
                been allocated.
        """
        assert attn_metadata is not None
        # Capture the graph.
        assert num_tokens not in self.graphs
        graph = torch.cuda.CUDAGraph()
        forward_ctx = set_forward_context(
            attn_metadata=attn_metadata,
            vllm_config=self.vllm_config,
            num_tokens=num_tokens,
            cudagraph_runtime_mode=CUDAGraphMode.NONE,
            num_tokens_across_dp=num_tokens_across_dp,
            slot_mapping=slot_mappings,
        )
        with forward_ctx:
            with torch.cuda.graph(graph, self.pool):
                hidden_states = model(
                    input_ids=input_ids,
                    positions=positions,
                    inputs_embeds=inputs_embeds,
                )
                assert self.hidden_states is not None
                self.hidden_states[:num_tokens] = hidden_states
        self.graphs[num_tokens] = graph

    def _capture_piecewise_graph(
        self,
        num_tokens: int,
        num_reqs: int,
        model: nn.Module,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        inputs_embeds: torch.Tensor | None,
        num_tokens_across_dp: torch.Tensor,
        attn_metadata: dict[str, Any] | None,
        slot_mappings: dict[str, torch.Tensor] | None,
        has_lora: bool = False,
    ) -> None:
        """Run the model once under a PIECEWISE forward context.

        No graph object is created or stored here. The `attn_metadata` and
        `num_reqs` arguments are not used; the forward context is given
        `attn_metadata=None`.
        """
        # create batch descriptor for piecewise cudagraph dispatch key
        batch_descriptor = BatchDescriptor(num_tokens=num_tokens, has_lora=has_lora)

        # Capture run - CUDAGraphWrapper inside torch.compile will auto capture.
        piecewise_ctx = set_forward_context(
            attn_metadata=None,  # piecewise no need attn_metadata
            vllm_config=self.vllm_config,
            num_tokens=num_tokens,
            cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
            num_tokens_across_dp=num_tokens_across_dp,
            batch_descriptor=batch_descriptor,
            slot_mapping=slot_mappings,
        )
        with piecewise_ctx:
            model(
                input_ids=input_ids,
                positions=positions,
                inputs_embeds=inputs_embeds,
            )

    @torch.inference_mode()
    def capture(
        self,
        model: nn.Module,
        input_buffers: InputBuffers,
        mrope_positions: torch.Tensor | None,
        inputs_embeds: torch.Tensor | None,
        block_tables: BlockTables,
        attn_metadata_builders: list[AttentionMetadataBuilder],
        kv_cache_config: KVCacheConfig,
        has_lora: bool = False,
    ) -> None:
        """Capture every graph required by the configured cudagraph mode.

        All arguments are forwarded unchanged to `capture_graph` for each size
        via `capture_graphs`.
        """
        shared_kwargs = {
            "device": self.device,
            "capture_fn": self.capture_graph,
            "model": model,
            "input_buffers": input_buffers,
            "mrope_positions": mrope_positions,
            "inputs_embeds": inputs_embeds,
            "block_tables": block_tables,
            "attn_metadata_builders": attn_metadata_builders,
            "kv_cache_config": kv_cache_config,
            "has_lora": has_lora,
        }

        # Phase 1: Capture for mixed prefill-decode batches if needed.
        mixed_mode = self.cudagraph_mode.mixed_mode()
        if mixed_mode != CUDAGraphMode.NONE:
            capture_graphs(
                cudagraph_sizes=self.cudagraph_sizes,
                capture_cudagraph_mode=mixed_mode,
                desc=f"Capturing CUDA graphs (mixed, {mixed_mode.name})",
                uniform_decode=False,
                **shared_kwargs,
            )

        # Phase 2: Capture FULL graphs for uniform decode batches if needed.
        # This is only needed if we use a separate routine for decode batches
        # and the decode_mode is FULL.
        if self.uniform_decode_cudagraph_sizes:
            capture_graphs(
                cudagraph_sizes=self.uniform_decode_cudagraph_sizes,
                capture_cudagraph_mode=CUDAGraphMode.FULL,
                desc="Capturing CUDA graphs (decode, FULL)",
                uniform_decode=True,
                **shared_kwargs,
            )

    def get_cudagraph_runtime_mode(
        self, num_reqs: int, num_tokens: int, max_query_len: int
    ) -> tuple[CUDAGraphMode, int | None]:
        """Choose the runtime mode and captured size for a batch.

        A batch counts as uniform decode when `max_query_len` equals
        `uniform_decode_query_len` and `num_tokens` equals
        `max_query_len * num_reqs`.

        Returns:
            `(mode, size)`. When no captured size serves the batch the result
            is `(CUDAGraphMode.NONE, None)`.
        """
        is_uniform_decode = (
            max_query_len == self.uniform_decode_query_len
            and num_tokens == max_query_len * num_reqs
        )

        cudagraph_size = self.get_cudagraph_size(num_tokens, is_uniform_decode)
        if cudagraph_size is None:
            return CUDAGraphMode.NONE, cudagraph_size
        if is_uniform_decode:
            return self.cudagraph_mode.decode_mode(), cudagraph_size
        return self.cudagraph_mode.mixed_mode(), cudagraph_size

    def run_fullgraph(self, num_tokens: int) -> torch.Tensor:
        """Replay the stored graph for `num_tokens` and return its output slice.

        Raises:
            AssertionError: If no graph was stored for `num_tokens`, or the
                output buffer has not been allocated.
        """
        assert num_tokens in self.graphs, f"No cudagraph for {num_tokens} tokens"
        graph = self.graphs[num_tokens]
        graph.replay()
        assert self.hidden_states is not None
        return self.hidden_states[:num_tokens]


def get_cudagraph_sizes(
    capture_sizes: list[int] | None,
    max_num_reqs: int,
    max_num_tokens: int,
    cudagraph_mode: CUDAGraphMode,
    uniform_decode_query_len: int = 1,
    uniform_decode_cudagraph: bool = False,
) -> tuple[dict[int, int], dict[int, int]]:
    """Build the token-count -> captured-size lookup tables.

    Args:
        capture_sizes: Batch sizes to capture; None or empty disables graphs.
        max_num_reqs: Maximum number of requests in a batch; bounds the
            uniform-decode table together with `uniform_decode_query_len`.
        max_num_tokens: Accepted for interface compatibility. Its value is not
            consulted when building either table.
        cudagraph_mode: When NONE, both tables are empty.
        uniform_decode_query_len: Tokens per request in a uniform-decode batch.
        uniform_decode_cudagraph: Whether to build the uniform-decode table.

    Returns:
        `(cudagraph_sizes, uniform_decode_cudagraph_sizes)`. The first maps
        every token count from 1 up to the largest capture size to the
        smallest capture size that is >= that count. The second is the subset
        of the first whose captured size lies within
        `[uniform_decode_query_len, max_num_reqs * uniform_decode_query_len]`,
        or empty when `uniform_decode_cudagraph` is false.
    """
    # Support both FULL and PIECEWISE cudagraph modes
    if cudagraph_mode == CUDAGraphMode.NONE or not capture_sizes:
        return {}, {}

    # Walk the sizes in ascending order and assign each still-uncovered token
    # count to the current size. Every count is written exactly once, in
    # ascending key order, instead of rescanning the size list per count.
    cudagraph_sizes: dict[int, int] = {}
    next_uncovered = 1
    for size in sorted(capture_sizes):
        for token_count in range(next_uncovered, size + 1):
            cudagraph_sizes[token_count] = size
        # max() keeps the cursor from moving backwards on non-positive sizes.
        next_uncovered = max(next_uncovered, size + 1)

    uniform_decode_cudagraph_sizes: dict[int, int] = {}
    if uniform_decode_cudagraph:
        # Use a separate local rather than overwriting the parameter.
        max_decode_tokens = max_num_reqs * uniform_decode_query_len
        uniform_decode_cudagraph_sizes = {
            token_count: size
            for token_count, size in cudagraph_sizes.items()
            if uniform_decode_query_len <= size <= max_decode_tokens
        }
    return cudagraph_sizes, uniform_decode_cudagraph_sizes
332
333
334
335
336
337


def capture_graphs(
    cudagraph_sizes: dict[int, int],
    device: torch.device,
    capture_fn: Callable,
    capture_cudagraph_mode: CUDAGraphMode,
    desc: str = "Capturing CUDA graphs",
    **capture_kwargs,
) -> None:
    """Invoke `capture_fn` once per distinct captured size, largest first.

    Args:
        cudagraph_sizes: Lookup table; only its values (the captured sizes)
            are used, with duplicates removed.
        device: Passed to `graph_capture`.
        capture_fn: Called as
            `capture_fn(size, capture_cudagraph_mode, **capture_kwargs)`.
        capture_cudagraph_mode: Forwarded to `capture_fn`.
        desc: Progress-bar label; a bar is shown only on the global first rank.
        **capture_kwargs: Forwarded to `capture_fn` unchanged.
    """
    # Capture larger graphs first.
    distinct_sizes = set(cudagraph_sizes.values())
    ordered_sizes = sorted(distinct_sizes, reverse=True)
    size_iter = (
        tqdm(ordered_sizes, desc=desc) if is_global_first_rank() else ordered_sizes
    )

    with graph_capture(device=device):
        for graph_size in size_iter:
            capture_fn(graph_size, capture_cudagraph_mode, **capture_kwargs)
350
351
352
353
354
355
356
357
358
359


def prepare_inputs_to_capture(
    num_reqs: int,
    num_tokens: int,
    input_buffers: InputBuffers,
    block_tables: BlockTables,
    attn_metadata_builders: list[AttentionMetadataBuilder],
    max_model_len: int,
    kv_cache_config: KVCacheConfig,
    uniform_decode_query_len: int = 0,
) -> tuple[dict[str, Any], dict[str, torch.Tensor]]:
    """Fill the input buffers with a dummy batch and build its attention inputs.

    Writes `query_start_loc`, `seq_lens` and `dcp_local_seq_lens` on
    `input_buffers` in place.

    Args:
        num_reqs: Number of dummy requests.
        num_tokens: Total number of tokens in the dummy batch.
        input_buffers: Buffers that are overwritten for the dummy batch.
        block_tables: Source of the block tables and slot mappings.
        attn_metadata_builders: Forwarded to `build_attn_metadata`.
        max_model_len: Passed to `build_attn_metadata` as `max_seq_len`.
        kv_cache_config: Forwarded to both builder helpers.
        uniform_decode_query_len: When > 0, tokens per request; otherwise the
            per-request count is `num_tokens // num_reqs`.

    Returns:
        `(attn_metadata, slot_mappings_by_layer)` as produced by
        `build_attn_metadata` and `build_slot_mappings_by_layer`.
    """
    tokens_per_req = (
        uniform_decode_query_len
        if uniform_decode_query_len > 0
        else num_tokens // num_reqs
    )

    # Evenly spaced request boundaries; the final boundary is pinned to
    # num_tokens so the last request ends exactly at the batch end.
    boundaries_np = np.arange(num_reqs + 1, dtype=np.int32) * tokens_per_req
    boundaries_np[-1] = num_tokens
    boundaries_cpu = torch.from_numpy(boundaries_np)
    num_boundaries = num_reqs + 1
    input_buffers.query_start_loc[:num_boundaries] = boundaries_cpu
    # Entries past the used prefix are set to num_tokens.
    input_buffers.query_start_loc[num_boundaries:] = num_tokens
    query_start_loc_view = input_buffers.query_start_loc[:num_boundaries]

    # HACK(woosuk): For faster warmup, we set seq_lens (GPU) to num_tokens
    # rather than max_model_len.
    for lens_buffer in (input_buffers.seq_lens, input_buffers.dcp_local_seq_lens):
        lens_buffer[:num_reqs] = num_tokens
        lens_buffer[num_reqs:] = 0

    per_group_block_tables = [
        table[:num_reqs] for table in block_tables.input_block_tables
    ]
    slot_mappings = block_tables.slot_mappings[:, :num_tokens]
    slot_mappings_by_layer = build_slot_mappings_by_layer(
        slot_mappings, kv_cache_config
    )

    # NOTE(review): on the non-uniform path, when num_tokens is not a multiple
    # of num_reqs the last request is longer than tokens_per_req, yet
    # max_query_len below is tokens_per_req - confirm this is intended.
    attn_metadata = build_attn_metadata(
        attn_metadata_builders=attn_metadata_builders,
        num_reqs=num_reqs,
        num_tokens=num_tokens,
        query_start_loc_gpu=query_start_loc_view,
        query_start_loc_cpu=boundaries_cpu,
        max_query_len=tokens_per_req,
        seq_lens=input_buffers.seq_lens,
        max_seq_len=max_model_len,
        block_tables=per_group_block_tables,
        slot_mappings=slot_mappings,
        kv_cache_config=kv_cache_config,
        dcp_local_seq_lens=input_buffers.dcp_local_seq_lens,
    )
    return attn_metadata, slot_mappings_by_layer