cudagraph_utils.py 15 KB
Newer Older
Woosuk Kwon's avatar
Woosuk Kwon committed
1
2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from collections.abc import Callable
4
from typing import Any
Woosuk Kwon's avatar
Woosuk Kwon committed
5
6
7
8
9
10
11
12
13

import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm

from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode
from vllm.distributed.parallel_state import graph_capture, is_global_first_rank
14
15
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.utils.math_utils import cdiv
Woosuk Kwon's avatar
Woosuk Kwon committed
16
from vllm.v1.kv_cache_interface import KVCacheConfig
17
18
19
20
from vllm.v1.worker.gpu.attn_utils import (
    build_attn_metadata,
    build_slot_mappings_by_layer,
)
Woosuk Kwon's avatar
Woosuk Kwon committed
21
from vllm.v1.worker.gpu.block_table import BlockTables
22
from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp
Woosuk Kwon's avatar
Woosuk Kwon committed
23
from vllm.v1.worker.gpu.input_batch import InputBuffers
24
from vllm.v1.worker.utils import AttentionGroup
Woosuk Kwon's avatar
Woosuk Kwon committed
25
26
27


class CudaGraphManager:
    """Plans, captures and replays CUDA graphs for the GPU model runner.

    FULL graphs are captured here with ``torch.cuda.graph`` and stored in
    ``self.graphs`` keyed by the padded token count; replaying one writes the
    model output into the persistent ``self.hidden_states`` buffer.
    PIECEWISE graphs are not stored here: capture is triggered by running the
    model under a PIECEWISE forward context (per the in-code comment in
    ``_capture_piecewise_graph``, the CUDAGraphWrapper inside torch.compile
    records them).
    """

    def __init__(self, vllm_config: VllmConfig, uses_mrope: bool, device: torch.device):
        self.vllm_config = vllm_config
        self.scheduler_config = vllm_config.scheduler_config
        self.uses_mrope = uses_mrope
        self.device = device

        self.max_model_len = vllm_config.model_config.max_model_len
        self.max_num_reqs = self.scheduler_config.max_num_seqs
        self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
        self.dp_size = vllm_config.parallel_config.data_parallel_size

        # Tokens per request in a uniform decode batch: one, plus the
        # speculative tokens when a speculative config is present.
        self.uniform_decode_query_len = 1
        spec_config = vllm_config.speculative_config
        if spec_config is not None:
            self.uniform_decode_query_len += spec_config.num_speculative_tokens

        self.compilation_config = vllm_config.compilation_config
        assert self.compilation_config is not None
        self.cudagraph_mode = self.compilation_config.cudagraph_mode

        # Dedicated uniform-decode graphs are only planned when decode batches
        # use a separate routine whose mode is FULL.
        use_uniform_decode_cudagraph = (
            self.cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
            and self.cudagraph_mode.separate_routine()
        )
        self.cudagraph_sizes, self.uniform_decode_cudagraph_sizes = get_cudagraph_sizes(
            self.compilation_config.cudagraph_capture_sizes,
            self.max_num_reqs,
            self.max_num_tokens,
            self.cudagraph_mode,
            self.uniform_decode_query_len,
            use_uniform_decode_cudagraph,
        )

        # Captured FULL graphs, keyed by the padded token count.
        self.graphs: dict[int, torch.cuda.CUDAGraph] = {}
        # Memory pool handed to torch.cuda.graph for FULL captures; None when
        # CUDA graphs are disabled.
        self.pool = None
        if self.cudagraph_mode != CUDAGraphMode.NONE:
            self.pool = torch.cuda.graph_pool_handle()
        # Persistent output buffer written by FULL graphs. Allocated lazily
        # (torch.empty_like) from the first warm-up output in capture_graph.
        self.hidden_states: torch.Tensor | None = None

    def needs_capture(self) -> bool:
        """Return True when at least one capture size was planned."""
        return len(self.cudagraph_sizes) > 0

    def get_cudagraph_size(
        self, num_tokens: int, uniform_decode: bool = False
    ) -> int | None:
        """Return the padded graph size for ``num_tokens``, or None if none fits.

        Uniform decode batches consult the dedicated uniform-decode map when it
        is non-empty; otherwise the general map is used.
        """
        if uniform_decode and self.uniform_decode_cudagraph_sizes:
            return self.uniform_decode_cudagraph_sizes.get(num_tokens)
        return self.cudagraph_sizes.get(num_tokens)

    def capture_graph(
        self,
        num_tokens: int,
        capture_cg_mode: CUDAGraphMode,
        model: nn.Module,
        input_buffers: InputBuffers,
        mrope_positions: torch.Tensor | None,
        inputs_embeds: torch.Tensor | None,
        block_tables: BlockTables,
        attn_groups: list[list[AttentionGroup]],
        kv_cache_config: KVCacheConfig,
        has_lora: bool = False,
        uniform_decode: bool = False,
    ) -> None:
        """Warm up on a dummy batch of ``num_tokens`` tokens, then capture it.

        Args:
            num_tokens: Padded token count of the graph to capture.
            capture_cg_mode: ``CUDAGraphMode.PIECEWISE`` or ``CUDAGraphMode.FULL``.
            model: Model to run.
            input_buffers: Persistent input buffers sliced to ``num_tokens``.
            mrope_positions: Positions used instead of ``input_buffers.positions``
                when ``self.uses_mrope`` is set (sliced along dim 1).
            inputs_embeds: Optional embeddings, sliced to ``num_tokens``.
            block_tables: Source of block tables and slot mappings.
            attn_groups: Attention groups passed to metadata building.
            kv_cache_config: KV-cache configuration for metadata building.
            has_lora: Forwarded to the capture routine.
            uniform_decode: Shape the dummy batch as a uniform decode batch.
        """
        assert capture_cg_mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], (
            f"Invalid capture_cudagraph_mode for capture: {capture_cg_mode}"
        )
        if capture_cg_mode == CUDAGraphMode.PIECEWISE:
            capture_fn = self._capture_piecewise_graph
        else:
            capture_fn = self._capture_full_graph

        # Number of dummy requests, capped by the scheduler limit.
        if uniform_decode:
            num_reqs = min(
                cdiv(num_tokens, self.uniform_decode_query_len),
                self.max_num_reqs,
            )
        else:
            num_reqs = min(num_tokens, self.max_num_reqs)

        input_ids = input_buffers.input_ids[:num_tokens]
        positions = input_buffers.positions[:num_tokens]
        if self.uses_mrope:
            assert mrope_positions is not None
            positions = mrope_positions[:, :num_tokens]
        if inputs_embeds is not None:
            inputs_embeds = inputs_embeds[:num_tokens]

        attn_metadata, slot_mappings = prepare_inputs_to_capture(
            num_reqs,
            num_tokens,
            input_buffers,
            block_tables,
            attn_groups,
            self.max_model_len,
            kv_cache_config,
            uniform_decode_query_len=(
                self.uniform_decode_query_len if uniform_decode else 0
            ),
        )
        num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens)

        # Warm up eagerly (runtime mode NONE) before capturing.
        with set_forward_context(
            attn_metadata,
            self.vllm_config,
            num_tokens=num_tokens,
            cudagraph_runtime_mode=CUDAGraphMode.NONE,
            num_tokens_across_dp=num_tokens_across_dp,
            slot_mapping=slot_mappings,
        ):
            hidden_states = model(
                input_ids=input_ids,
                positions=positions,
                inputs_embeds=inputs_embeds,
            )
            # The first warm-up output fixes the shape/dtype of the
            # persistent output buffer used by FULL graphs.
            if self.hidden_states is None:
                self.hidden_states = torch.empty_like(hidden_states)

        capture_fn(
            num_tokens=num_tokens,
            num_reqs=num_reqs,
            model=model,
            input_ids=input_ids,
            positions=positions,
            inputs_embeds=inputs_embeds,
            num_tokens_across_dp=num_tokens_across_dp,
            attn_metadata=attn_metadata,
            slot_mappings=slot_mappings,
            has_lora=has_lora,
        )

    def _capture_full_graph(
        self,
        num_tokens: int,
        num_reqs: int,
        model: nn.Module,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        inputs_embeds: torch.Tensor | None,
        num_tokens_across_dp: torch.Tensor,
        attn_metadata: dict[str, Any] | None,
        slot_mappings: dict[str, torch.Tensor] | None,
        has_lora: bool = False,
    ) -> None:
        """Record one FULL graph for ``num_tokens`` into ``self.graphs``.

        The forward context uses runtime mode NONE because the whole forward
        pass is recorded here by ``torch.cuda.graph``. The copy into
        ``self.hidden_states`` is part of the recorded graph, so replaying it
        refreshes that buffer. ``num_reqs`` and ``has_lora`` are accepted for a
        signature shared with ``_capture_piecewise_graph`` and are unused.
        """
        assert attn_metadata is not None
        assert num_tokens not in self.graphs
        graph = torch.cuda.CUDAGraph()
        with (
            set_forward_context(
                attn_metadata=attn_metadata,
                vllm_config=self.vllm_config,
                num_tokens=num_tokens,
                cudagraph_runtime_mode=CUDAGraphMode.NONE,
                num_tokens_across_dp=num_tokens_across_dp,
                slot_mapping=slot_mappings,
            ),
            torch.cuda.graph(graph, self.pool),
        ):
            hidden_states = model(
                input_ids=input_ids,
                positions=positions,
                inputs_embeds=inputs_embeds,
            )
            assert self.hidden_states is not None
            self.hidden_states[:num_tokens] = hidden_states
        self.graphs[num_tokens] = graph

    def _capture_piecewise_graph(
        self,
        num_tokens: int,
        num_reqs: int,
        model: nn.Module,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        inputs_embeds: torch.Tensor | None,
        num_tokens_across_dp: torch.Tensor,
        attn_metadata: dict[str, Any] | None,
        slot_mappings: dict[str, torch.Tensor] | None,
        has_lora: bool = False,
    ) -> None:
        """Run the model once under a PIECEWISE context to trigger capture.

        Nothing is stored in ``self.graphs``. The ``BatchDescriptor`` built
        from ``num_tokens`` and ``has_lora`` is the dispatch key placed in the
        forward context. ``num_reqs`` and ``attn_metadata`` are unused.
        """
        batch_descriptor = BatchDescriptor(num_tokens=num_tokens, has_lora=has_lora)

        # Capture run - CUDAGraphWrapper inside torch.compile will auto capture.
        with set_forward_context(
            attn_metadata=None,  # piecewise no need attn_metadata
            vllm_config=self.vllm_config,
            num_tokens=num_tokens,
            cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
            num_tokens_across_dp=num_tokens_across_dp,
            batch_descriptor=batch_descriptor,
            slot_mapping=slot_mappings,
        ):
            model(
                input_ids=input_ids,
                positions=positions,
                inputs_embeds=inputs_embeds,
            )

    @torch.inference_mode()
    def capture(
        self,
        model: nn.Module,
        input_buffers: InputBuffers,
        mrope_positions: torch.Tensor | None,
        inputs_embeds: torch.Tensor | None,
        block_tables: BlockTables,
        attn_groups: list[list[AttentionGroup]],
        kv_cache_config: KVCacheConfig,
        has_lora: bool = False,
    ) -> None:
        """Capture every planned graph.

        Phase 1 captures the general sizes with the mixed-batch mode (skipped
        when that mode is NONE). Phase 2 captures FULL graphs for the
        uniform-decode sizes, if any were planned.
        """
        common_kwargs = dict(
            device=self.device,
            capture_fn=self.capture_graph,
            model=model,
            input_buffers=input_buffers,
            mrope_positions=mrope_positions,
            inputs_embeds=inputs_embeds,
            block_tables=block_tables,
            attn_groups=attn_groups,
            kv_cache_config=kv_cache_config,
            has_lora=has_lora,
        )

        # Phase 1: Capture for mixed prefill-decode batches if needed.
        mixed_mode = self.cudagraph_mode.mixed_mode()
        if mixed_mode != CUDAGraphMode.NONE:
            capture_graphs(
                cudagraph_sizes=self.cudagraph_sizes,
                capture_cudagraph_mode=mixed_mode,
                desc=f"Capturing CUDA graphs (mixed, {mixed_mode.name})",
                uniform_decode=False,
                **common_kwargs,
            )

        # Phase 2: Capture FULL graphs for uniform decode batches if needed.
        # This is only needed if we use a separate routine for decode batches
        # and the decode_mode is FULL.
        if self.uniform_decode_cudagraph_sizes:
            capture_graphs(
                cudagraph_sizes=self.uniform_decode_cudagraph_sizes,
                capture_cudagraph_mode=CUDAGraphMode.FULL,
                desc="Capturing CUDA graphs (decode, FULL)",
                uniform_decode=True,
                **common_kwargs,
            )

    def get_cudagraph_runtime_mode(
        self, num_reqs: int, num_tokens: int, max_query_len: int
    ) -> tuple[CUDAGraphMode, int | None]:
        """Choose the graph mode and padded size for a batch.

        A batch is a uniform decode batch when every request has exactly
        ``uniform_decode_query_len`` tokens.

        Returns:
            ``(mode, padded_size)``. Whenever no graph will run (no size fits,
            the routine for this batch kind has mode NONE, or the FULL graph
            has not been captured yet) the result is ``(CUDAGraphMode.NONE,
            None)``, so callers never pad a batch that runs eagerly.
        """
        is_uniform_decode = (max_query_len == self.uniform_decode_query_len) and (
            num_tokens == max_query_len * num_reqs
        )

        cudagraph_size = self.get_cudagraph_size(num_tokens, is_uniform_decode)
        if cudagraph_size is None:
            return CUDAGraphMode.NONE, None

        if is_uniform_decode:
            cudagraph_mode = self.cudagraph_mode.decode_mode()
        else:
            cudagraph_mode = self.cudagraph_mode.mixed_mode()

        if cudagraph_mode == CUDAGraphMode.NONE:
            # A size exists in the map, but this batch kind is not graphed
            # (e.g. mixed batches when only decode is graphed): run eagerly
            # without padding.
            return CUDAGraphMode.NONE, None

        if cudagraph_mode == CUDAGraphMode.FULL and cudagraph_size not in self.graphs:
            # If graph wasn't captured yet, fall back to eager.
            # This might happen when the dummy run is called before capture.
            return CUDAGraphMode.NONE, None

        return cudagraph_mode, cudagraph_size

    def run_fullgraph(self, num_tokens: int) -> torch.Tensor:
        """Replay the FULL graph captured for ``num_tokens``.

        Returns the first ``num_tokens`` rows of the persistent output buffer.
        This is a view that the next replay overwrites.
        """
        assert num_tokens in self.graphs, f"No cudagraph for {num_tokens} tokens"
        self.graphs[num_tokens].replay()
        assert self.hidden_states is not None
        return self.hidden_states[:num_tokens]


def get_cudagraph_sizes(
    capture_sizes: list[int] | None,
    max_num_reqs: int,
    max_num_tokens: int,
    cudagraph_mode: CUDAGraphMode,
    uniform_decode_query_len: int = 1,
    uniform_decode_cudagraph: bool = False,
) -> tuple[dict[int, int], dict[int, int]]:
    """Map each batch token count to the padded size of the graph serving it.

    Args:
        capture_sizes: Token counts to capture graphs for, in any order.
        max_num_reqs: Maximum requests per batch; bounds the uniform-decode map.
        max_num_tokens: Currently unused; kept for interface compatibility.
        cudagraph_mode: No sizes are planned when this is ``CUDAGraphMode.NONE``.
        uniform_decode_query_len: Tokens per request in a uniform decode batch.
        uniform_decode_cudagraph: Whether to build the uniform-decode map.

    Returns:
        ``(cudagraph_sizes, uniform_decode_cudagraph_sizes)``.

        - ``cudagraph_sizes`` maps every ``n`` in ``1..max(capture_sizes)`` to
          the smallest capture size ``>= n``.
        - ``uniform_decode_cudagraph_sizes`` is the subset of
          ``cudagraph_sizes`` whose padded size lies in
          ``[uniform_decode_query_len, max_num_reqs * uniform_decode_query_len]``.
          It is empty unless ``uniform_decode_cudagraph`` is True.

        Both are empty when graphs are disabled or no sizes are given.
    """
    # Support both FULL and PIECEWISE cudagraph modes
    if cudagraph_mode == CUDAGraphMode.NONE or not capture_sizes:
        return {}, {}

    # Single linear pass: each sorted size serves every token count above the
    # previous size, up to and including itself.
    cudagraph_sizes: dict[int, int] = {}
    covered = 0
    for size in sorted(capture_sizes):
        for num_tokens in range(covered + 1, size + 1):
            cudagraph_sizes[num_tokens] = size
        covered = max(covered, size)

    uniform_decode_cudagraph_sizes: dict[int, int] = {}
    if uniform_decode_cudagraph:
        max_uniform_decode_tokens = max_num_reqs * uniform_decode_query_len
        uniform_decode_cudagraph_sizes = {
            k: v
            for k, v in cudagraph_sizes.items()
            if uniform_decode_query_len <= v <= max_uniform_decode_tokens
        }
    return cudagraph_sizes, uniform_decode_cudagraph_sizes
342
343
344
345
346
347


def capture_graphs(
    cudagraph_sizes: dict[int, int],
    device: torch.device,
    capture_fn: Callable,
    capture_cudagraph_mode: CUDAGraphMode,
    desc: str = "Capturing CUDA graphs",
    **capture_kwargs,
) -> None:
    """Call ``capture_fn`` once for each distinct padded size, largest first.

    Args:
        cudagraph_sizes: Token-count -> padded-size map. Only its distinct
            values are captured.
        device: Passed to ``graph_capture``, which wraps the whole loop.
        capture_fn: Invoked as
            ``capture_fn(size, capture_cudagraph_mode, **capture_kwargs)``.
        capture_cudagraph_mode: Forwarded positionally to ``capture_fn``.
        desc: Progress-bar label. The bar is shown only on the global first rank.
        **capture_kwargs: Extra keyword arguments forwarded to ``capture_fn``.
    """
    # Descending order: larger graphs are captured before smaller ones.
    distinct_sizes = sorted(set(cudagraph_sizes.values()), reverse=True)
    progress = (
        tqdm(distinct_sizes, desc=desc) if is_global_first_rank() else distinct_sizes
    )

    with graph_capture(device=device):
        for padded_size in progress:
            capture_fn(padded_size, capture_cudagraph_mode, **capture_kwargs)
360
361
362
363
364
365
366


def prepare_inputs_to_capture(
    num_reqs: int,
    num_tokens: int,
    input_buffers: InputBuffers,
    block_tables: BlockTables,
    attn_groups: list[list[AttentionGroup]],
    max_model_len: int,
    kv_cache_config: KVCacheConfig,
    uniform_decode_query_len: int = 0,
) -> tuple[dict[str, Any], dict[str, torch.Tensor]]:
    """Fill the input buffers with a dummy batch and build capture metadata.

    The ``num_tokens`` tokens are split across ``num_reqs`` requests. Every
    request gets ``num_tokens_per_req`` tokens except the last, which ends at
    ``num_tokens``. ``num_tokens_per_req`` is ``uniform_decode_query_len``
    when that is positive, otherwise ``num_tokens // num_reqs``.

    Args:
        num_reqs: Number of dummy requests (must be >= 1).
        num_tokens: Total tokens in the dummy batch.
        input_buffers: Persistent buffers whose ``query_start_loc``,
            ``seq_lens`` and ``dcp_local_seq_lens`` are overwritten here.
        block_tables: Source of input block tables and slot mappings.
        attn_groups: Attention groups passed to ``build_attn_metadata``.
        max_model_len: Reported as ``max_seq_len`` in the metadata.
        kv_cache_config: KV-cache configuration for metadata building.
        uniform_decode_query_len: Tokens per request for a uniform decode
            batch; 0 selects the even split.

    Returns:
        ``(attn_metadata, slot_mappings_by_layer)``.
    """
    if uniform_decode_query_len > 0:
        num_tokens_per_req = uniform_decode_query_len
    else:
        num_tokens_per_req = num_tokens // num_reqs

    query_start_loc_np = np.arange(num_reqs + 1, dtype=np.int32) * num_tokens_per_req
    query_start_loc_np[-1] = num_tokens
    query_start_loc_cpu = torch.from_numpy(query_start_loc_np)
    input_buffers.query_start_loc[: num_reqs + 1] = query_start_loc_cpu
    input_buffers.query_start_loc[num_reqs + 1 :] = num_tokens
    query_start_loc = input_buffers.query_start_loc[: num_reqs + 1]

    # The last request ends at num_tokens, so for an even split it absorbs the
    # remainder (10 tokens over 4 requests -> [2, 2, 2, 4]). Report the true
    # maximum so max_query_len agrees with query_start_loc. For uniform decode
    # the last request is never longer, so this stays num_tokens_per_req.
    last_query_len = num_tokens - (num_reqs - 1) * num_tokens_per_req
    max_query_len = max(num_tokens_per_req, last_query_len)

    # HACK(woosuk): For faster warmup, we set seq_lens (GPU) to num_tokens
    # rather than max_model_len.
    input_buffers.seq_lens[:num_reqs] = num_tokens
    input_buffers.seq_lens[num_reqs:] = 0

    input_buffers.dcp_local_seq_lens[:num_reqs] = num_tokens
    input_buffers.dcp_local_seq_lens[num_reqs:] = 0

    input_block_tables = [x[:num_reqs] for x in block_tables.input_block_tables]
    slot_mappings = block_tables.slot_mappings[:, :num_tokens]
    slot_mappings_by_layer = build_slot_mappings_by_layer(
        slot_mappings, kv_cache_config
    )

    attn_metadata = build_attn_metadata(
        attn_groups=attn_groups,
        num_reqs=num_reqs,
        num_tokens=num_tokens,
        query_start_loc_gpu=query_start_loc,
        query_start_loc_cpu=query_start_loc_cpu,
        max_query_len=max_query_len,
        seq_lens=input_buffers.seq_lens,
        max_seq_len=max_model_len,
        block_tables=input_block_tables,
        slot_mappings=slot_mappings,
        kv_cache_config=kv_cache_config,
        dcp_local_seq_lens=input_buffers.dcp_local_seq_lens,
    )
    return attn_metadata, slot_mappings_by_layer