# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import threading
5
from collections.abc import Callable
6
from dataclasses import dataclass
7
from typing import Any
8
9
10

import torch

11
import vllm.envs as envs
12
13
from vllm.compilation.cuda_graph import CUDAGraphWrapper
from vllm.config import CUDAGraphMode, VllmConfig
14
from vllm.distributed import get_ep_group
15
16
from vllm.distributed.device_communicators.pynccl_allocator import set_graph_pool_id
from vllm.forward_context import (
17
    DPMetadata,
18
19
20
21
    create_forward_context,
    get_forward_context,
    override_forward_context,
)
22
from vllm.logger import init_logger
23
from vllm.model_executor.offloader.base import get_offloader
24
25
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
26
from vllm.utils.deep_gemm import set_num_sms as deep_gemm_set_num_sms
27
from vllm.utils.import_utils import has_deep_gemm
28
from vllm.utils.platform_utils import num_compute_units
29
30
31
32
33
from vllm.v1.worker.ubatching import UBatchContext, make_ubatch_contexts

logger = init_logger(__name__)


34
@dataclass
class UbatchMetadata:
    """Inputs and execution context for a single micro-batch (ubatch)."""

    # Context manager entered by the ubatch worker thread around the model
    # call; also exposes the streams, id and wake-up event used by the wrapper.
    context: UBatchContext
    # Model inputs already restricted to this ubatch's token slice.
    input_ids: torch.Tensor
    positions: torch.Tensor
    inputs_embeds: torch.Tensor | None
    intermediate_tensors: IntermediateTensors | None
    # Length of this ubatch's token slice (stop - start).
    num_tokens: int


44
@dataclass
class CUDAGraphMetaData:
    """A captured micro-batched CUDA graph together with what it was built from."""

    # The graph object that is replayed on later calls.
    cudagraph: torch.cuda.CUDAGraph
    # One entry per ubatch. `_capture_ubatches` stores the whole list it was
    # given (it indexes and iterates it), so this is a list rather than a
    # single UbatchMetadata.
    ubatch_metadata: list[UbatchMetadata]
    # Concatenated model output produced during capture; returned again after
    # every replay of `cudagraph`.
    outputs: Any | None = None
49
50


51
class SMControlContextManager:
    """
    Context manager for controlling SM (Streaming Multiprocessor) allocation.

    Entering the context hands ``comm_sms`` SMs to communication and the
    remaining ``total_sms - comm_sms`` to computation. Leaving the context
    gives both sides all available SMs (``total_sms``) again.
    """

    def __init__(
        self,
        comm_sms: int,
        set_comm_sms: Callable[[int], None],
        set_compute_sms: Callable[[int], None],
    ):
        """
        Args:
            comm_sms (int): The number of SMs to allocate for communication.
                (The remainder will be used for computation.)
            set_comm_sms (Callable[[int], None]):
                A function that sets the number of SMs for communication.
            set_compute_sms (Callable[[int], None]):
                A function that sets the number of SMs for computation.
        """
        assert current_platform.is_cuda(), (
            "SM control is currently only supported on CUDA"
        )
        available_sms = num_compute_units(torch.accelerator.current_device_index())

        # Communication may never take every SM on the device.
        assert comm_sms < available_sms

        self.set_comm_sms = set_comm_sms
        self.set_compute_sms = set_compute_sms
        self.comm_sms = comm_sms
        self.compute_sms = available_sms - comm_sms
        self.total_sms = available_sms

    def _apply(self, comm: int, compute: int) -> None:
        # Communication first, then computation, on both enter and exit.
        self.set_comm_sms(comm)
        self.set_compute_sms(compute)

    def __enter__(self):
        self._apply(self.comm_sms, self.compute_sms)

    def __exit__(self, exc_type, exc_value, traceback):
        # Restore the full SM budget to both sides; exceptions are not
        # suppressed.
        self._apply(self.total_sms, self.total_sms)


96
class UBatchWrapper:
97
98
99
100
101
102
103
    def __init__(
        self,
        runnable: Callable,
        vllm_config: VllmConfig,
        runtime_mode: CUDAGraphMode,
        device: torch.cuda.device,
    ):
        """
        Wrap ``runnable`` so it can be executed as several micro-batches.

        Args:
            runnable: The callable to run, either whole or once per ubatch.
            vllm_config: Global configuration; its parallel and compilation
                sub-configs are read here.
            runtime_mode: When not ``CUDAGraphMode.NONE``, a
                ``CUDAGraphWrapper`` around ``runnable`` is also created for
                the non-ubatched path.
            device: Device on which the communication stream is created and
                which the capture threads select.
        """
        self.runnable = runnable
        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config
        self.device = device
        self.comm_stream = torch.cuda.Stream(device=device)

        # One party per ubatch thread, plus the main thread.
        barrier_parties = vllm_config.parallel_config.num_ubatches + 1
        self.ready_barrier = threading.Barrier(barrier_parties)

        # Captured ubatched graphs, keyed by total number of tokens.
        self.cudagraphs: dict[int, CUDAGraphMetaData] = {}

        self.cudagraph_wrapper = (
            CUDAGraphWrapper(runnable, vllm_config, runtime_mode=runtime_mode)
            if runtime_mode is not CUDAGraphMode.NONE
            else None
        )

        self.sm_control = self._create_sm_control_context(vllm_config)

        # Only pay for str(runnable) when it can end up in an error message.
        debugging = envs.VLLM_LOGGING_LEVEL == "DEBUG"
        self.is_debugging_mode = debugging
        self._runnable_str = str(runnable) if debugging else None
125

126
127
128
129
130
131
132
133
134
135
136
    @property
    def graph_pool(self):
        if self.cudagraph_wrapper is not None:
            return self.cudagraph_wrapper.graph_pool
        return None

    def clear_graphs(self) -> None:
        self.cudagraphs.clear()
        if self.cudagraph_wrapper is not None:
            self.cudagraph_wrapper.clear_graphs()

137
138
    @staticmethod
    def _create_sm_control_context(vllm_config: VllmConfig):
        """
        Build the SMControlContextManager used around ubatched execution.

        The communication SM budget starts at ``envs.VLLM_DBO_COMM_SMS`` and,
        with expert parallelism enabled, is capped by what the all2all manager
        reports via ``max_sms_used()``. Setters that have nothing to control
        are replaced by no-ops.
        """

        def _noop(sms):
            return None

        comm_sms: int = envs.VLLM_DBO_COMM_SMS

        comm_setter = _noop
        if vllm_config.parallel_config.enable_expert_parallel:
            # Currently only DeepEP highthroughput supports SM control so this
            # only affects that case.
            communicator = get_ep_group().device_communicator
            manager = None if communicator is None else communicator.all2all_manager

            if manager is not None:
                sms_cap = manager.max_sms_used()
                if sms_cap is not None:
                    comm_sms = min(comm_sms, sms_cap)

                if comm_sms > 0:

                    def _set_all2all_sms(sms):
                        return manager.set_num_sms(sms)

                    comm_setter = _set_all2all_sms

        # TODO(lucas): support other kernels besides DeepGEMM
        compute_setter = _noop
        if has_deep_gemm() and comm_sms > 0:

            def _set_deep_gemm_sms(sms):
                return deep_gemm_set_num_sms(sms)

            compute_setter = _set_deep_gemm_sms

        return SMControlContextManager(
            comm_sms=comm_sms,
            set_comm_sms=comm_setter,
            set_compute_sms=compute_setter,
        )
169

170
171
172
173
    def __getattr__(self, key: str):
        # allow accessing the attributes of the runnable.
        if hasattr(self.runnable, key):
            return getattr(self.runnable, key)
174
175
176
177
178
179
        if self.is_debugging_mode:
            raise AttributeError(
                f"Attribute {key} not exists in the runnable of "
                f"cudagraph wrapper: {self._runnable_str}"
            )
        raise AttributeError
180
181
182
183
184
185
186
187
188
189
190
191
192
193

    def unwrap(self) -> Callable:
        # in case we need to access the original runnable.
        return self.runnable

    def _capture_ubatches(self, ubatch_metadata, model) -> torch.Tensor:
        """
        Capture a cudagraph for a microbatched run.

        The logic here is somewhat complicated because we need to make sure that
        each of the ubatch threads initialize the cuda context before we start
        the graph capture.

        The flow is as follows:
        1. The main thread starts up each ubatch thread. Each thread will
        initialize its cuda context (torch.cuda.current_blas_handle())
        before going to sleep upon entering the ubatch_context.

        2. The main thread starts the graph capture and wakes up the first
        ubatch thread.

        3. Each ubatch thread runs the model to completion and returns the
        completed output tensors back to the main thread.

        4. The main thread stores the captured cudagraph along with its metadata
        and returns

        Args:
            ubatch_metadata: One UbatchMetadata per ubatch, in ubatch order.
            model: Callable invoked once per ubatch with that ubatch's inputs.

        Returns:
            The per-ubatch outputs concatenated along dim 0, ordered by
            ubatch context id. The same tensor is kept on the stored graph
            metadata.
        """

        @torch.inference_mode()
        def _capture_ubatch_thread(results, ubatch_metadata):
            torch.accelerator.set_device_index(self.device)
            ubatch_context = ubatch_metadata.context
            # Touch the BLAS handle on both streams before capture starts.
            with torch.cuda.stream(ubatch_context.compute_stream):
                _ = torch.cuda.current_blas_handle()
            with torch.cuda.stream(ubatch_context.comm_stream):
                _ = torch.cuda.current_blas_handle()
            with ubatch_context:
                model_output = model(
                    input_ids=ubatch_metadata.input_ids,
                    positions=ubatch_metadata.positions,
                    intermediate_tensors=ubatch_metadata.intermediate_tensors,
                    inputs_embeds=ubatch_metadata.inputs_embeds,
                )

            results.append((ubatch_metadata.context.id, model_output))

        results: list[tuple[int, torch.Tensor]] = []
        compute_stream = ubatch_metadata[0].context.compute_stream
        # Key the graph by the token count of ALL ubatches. __call__ looks
        # graphs up by the sum over every ubatch slice, so adding up only the
        # first two entries would store the graph under a key that is never
        # found again once more than two ubatches are in use.
        num_tokens = sum(metadata.num_tokens for metadata in ubatch_metadata)

        # Ubatches will manually manage the forward context, so we override
        # it to None here so we can have it restored correctly later
        with override_forward_context(None):
            ubatch_threads = []
            for metadata in ubatch_metadata:
                thread = threading.Thread(
                    target=_capture_ubatch_thread,
                    args=(
                        results,
                        metadata,
                    ),
                )
                ubatch_threads.append(thread)
                thread.start()
            self.ready_barrier.wait()  # Wait for all ubatch threads to be ready

            # Capture the cudagraph
            cudagraph_metadata = CUDAGraphMetaData(
                cudagraph=torch.cuda.CUDAGraph(),
                ubatch_metadata=ubatch_metadata,
            )
            if self.graph_pool is not None:
                set_graph_pool_id(self.graph_pool)
            else:
                set_graph_pool_id(current_platform.graph_pool_handle())

            # Sync offloader's copy stream before capture.
            # Ensure any pre-capture prefetches from offloader are complete.
            get_offloader().sync_prev_onload()

            with torch.cuda.graph(
                cudagraph_metadata.cudagraph,
                stream=compute_stream,
                pool=self.graph_pool,
            ):
                # Wake the first ubatch thread now that capture has begun.
                ubatch_metadata[0].context.cpu_wait_event.set()
                for thread in ubatch_threads:
                    thread.join()
                sorted_results = [value for position, value in sorted(results)]
                result = torch.cat(sorted_results, dim=0)
                cudagraph_metadata.outputs = result

                # Join offloader's copy stream after forward to avoid unjoined
                # stream error. The last layer's start_prefetch forks copy_stream,
                # but wait_prefetch only happens in the next forward pass.
                get_offloader().join_after_forward()
            self.cudagraphs[num_tokens] = cudagraph_metadata
        return cudagraph_metadata.outputs

    def _run_ubatches(self, ubatch_metadata, model) -> torch.Tensor:
        """
        Run ``model`` eagerly, one thread per ubatch, and merge the outputs.

        Args:
            ubatch_metadata: One UbatchMetadata per ubatch, in ubatch order.
            model: Callable invoked once per ubatch with that ubatch's inputs.

        Returns:
            The per-ubatch outputs concatenated along dim 0, ordered by
            ubatch context id.
        """
        outputs_by_id: list[tuple[int, torch.Tensor]] = []

        @torch.inference_mode()
        def _ubatch_thread(ubatch):
            ubatch_ctx = ubatch.context
            with ubatch_ctx:
                output = model(
                    input_ids=ubatch.input_ids,
                    positions=ubatch.positions,
                    intermediate_tensors=ubatch.intermediate_tensors,
                    inputs_embeds=ubatch.inputs_embeds,
                )
            outputs_by_id.append((ubatch_ctx.id, output))

        # The worker threads install their own forward context, so clear the
        # current one for the duration and let it be restored once every
        # worker has finished.
        with override_forward_context(None):
            workers = [
                threading.Thread(target=_ubatch_thread, args=(ubatch,))
                for ubatch in ubatch_metadata
            ]
            for worker in workers:
                worker.start()

            # Rendezvous with the workers, then release the first ubatch.
            self.ready_barrier.wait()
            ubatch_metadata[0].context.cpu_wait_event.set()

            for worker in workers:
                worker.join()

        ordered_outputs = [output for _, output in sorted(outputs_by_id)]
        return torch.cat(ordered_outputs, dim=0)

316
317
318
319
    def _make_ubatch_metadata(
        self,
        ubatch_slices,
        attn_metadata,
        slot_mapping,
        input_ids,
        positions,
        inputs_embeds,
        intermediate_tensors,
        compute_stream,
        dp_metadata,
        batch_descriptor,
        cudagraph_runtime_mode,
    ) -> list[UbatchMetadata]:
        """
        Build the per-ubatch forward contexts, ubatch contexts and inputs.

        ``attn_metadata``, ``dp_metadata`` and (when it is a non-empty list)
        ``slot_mapping`` are indexed by ubatch position. The model inputs are
        cut down to each ubatch's ``token_slice``.

        Returns:
            One UbatchMetadata per entry of ``ubatch_slices``, in order.
        """
        # slot_mapping can be None, an empty dict (from create_forward_context
        # converting None to {}), or a list of dicts (one per ubatch)
        per_ubatch_slots = (
            slot_mapping
            if (slot_mapping and isinstance(slot_mapping, list))
            else None
        )

        # Create one forward context per ubatch
        forward_contexts = [
            create_forward_context(
                None if attn_metadata is None else attn_metadata[idx],
                self.vllm_config,
                dp_metadata=dp_metadata[idx],
                batch_descriptor=batch_descriptor,
                cudagraph_runtime_mode=cudagraph_runtime_mode,
                slot_mapping=(
                    None if per_ubatch_slots is None else per_ubatch_slots[idx]
                ),
            )
            for idx, _ in enumerate(ubatch_slices)
        ]

        ubatch_ctxs = make_ubatch_contexts(
            num_micro_batches=len(ubatch_slices),
            comm_stream=self.comm_stream,
            compute_stream=compute_stream,
            forward_contexts=forward_contexts,
            ready_barrier=self.ready_barrier,
        )

        metadata_per_ubatch: list[UbatchMetadata] = []
        for idx, ubatch_slice in enumerate(ubatch_slices):
            token_slice = ubatch_slice.token_slice
            ids, pos, embeds, intermediates = self._slice_model_inputs(
                token_slice,
                input_ids,
                positions,
                inputs_embeds,
                intermediate_tensors,
            )
            metadata_per_ubatch.append(
                UbatchMetadata(
                    context=ubatch_ctxs[idx],
                    input_ids=ids,
                    positions=pos,
                    inputs_embeds=embeds,
                    intermediate_tensors=intermediates,
                    num_tokens=token_slice.stop - token_slice.start,
                )
            )

        return metadata_per_ubatch

383
384
385
386
387
388
389
390
    def _slice_model_inputs(
        self,
        tokens_slice: slice,
        input_ids,
        positions,
        inputs_embeds,
        intermediate_tensors,
    ):
391
        sliced_input_ids = input_ids[tokens_slice] if input_ids is not None else None
392
393
394
395
396
397
        # if we are using mrope. Mrope adds an additional dimension to the
        # positions tensor
        if positions.ndim == 2:
            sliced_positions = positions[:, tokens_slice]
        else:
            sliced_positions = positions[tokens_slice]
398
399
400
        sliced_inputs_embeds = (
            inputs_embeds[tokens_slice] if inputs_embeds is not None else None
        )
401
        sliced_intermediate_tensors = (
402
403
404
            intermediate_tensors[tokens_slice]
            if intermediate_tensors is not None
            else None
405
406
407
408
409
410
411
412
        )

        return (
            sliced_input_ids,
            sliced_positions,
            sliced_inputs_embeds,
            sliced_intermediate_tensors,
        )
413
414
415
416
417
418
419
420
421

    def __call__(self, *args, **kwargs):
        """
        Run the wrapped callable, micro-batched when the forward context asks.

        Without ubatch slices the call goes to the runnable directly or to the
        inner CUDA graph wrapper. With ubatch slices, in FULL cudagraph mode a
        previously captured graph for this token count is replayed, or else
        captured now. In any other mode the ubatches run eagerly.

        The ubatched path reads ``input_ids``, ``positions``,
        ``intermediate_tensors`` and ``inputs_embeds`` from ``kwargs``.
        """
        fwd_ctx = get_forward_context()
        batch_descriptor = fwd_ctx.batch_descriptor
        ubatch_slices = fwd_ctx.ubatch_slices
        runtime_mode = fwd_ctx.cudagraph_runtime_mode

        # If there's no ubatching, just run the runnable object
        if ubatch_slices is None:
            # This is to account for the case where ubatching was aborted.
            # When we capture full graphs we only capture one graph per shape,
            # meaning that if we have a ubatched cudagraph for the current
            # num_tokens, we don't have a non-ubatched one. Without this
            # check, the cudagraph wrapper will try to capture a cudagraph
            # for this shape during a normal run.
            if runtime_mode is CUDAGraphMode.FULL:
                assert batch_descriptor is not None
                if batch_descriptor.num_tokens in self.cudagraphs:
                    runtime_mode = CUDAGraphMode.NONE

            if runtime_mode in (CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE):
                return self.runnable(*args, **kwargs)
            assert self.cudagraph_wrapper is not None
            return self.cudagraph_wrapper(*args, **kwargs)

        attn_metadata = fwd_ctx.attn_metadata
        slot_mapping = fwd_ctx.slot_mapping
        num_tokens = sum(ubatch_slice.num_tokens for ubatch_slice in ubatch_slices)
        input_ids = kwargs["input_ids"]
        positions = kwargs["positions"]
        intermediate_tensors = kwargs["intermediate_tensors"]
        inputs_embeds = kwargs["inputs_embeds"]
        compute_stream = torch.cuda.current_stream()

        # We shouldn't be here unless we are running with multiple DP ranks
        assert fwd_ctx.dp_metadata is not None

        # Each ubatch gets its own DP metadata in which every DP rank is
        # recorded with that ubatch's token count.
        parallel_config = self.vllm_config.parallel_config
        dp_size = parallel_config.data_parallel_size
        ubatch_dp_metadata = [
            DPMetadata.make(
                parallel_config,
                ubatch_slice.num_tokens,
                torch.tensor(
                    [ubatch_slice.num_tokens] * dp_size,
                    device="cpu",
                    dtype=torch.int32,
                ),
            )
            for ubatch_slice in ubatch_slices
        ]

        full_graph_mode = runtime_mode is CUDAGraphMode.FULL

        if full_graph_mode and num_tokens in self.cudagraphs:
            captured = self.cudagraphs[num_tokens]
            # Sync offloader before replay - ensures any external dependencies
            # from pre-capture prefetches are satisfied.
            get_offloader().sync_prev_onload()
            captured.cudagraph.replay()
            return captured.outputs

        # Capture and eager execution need the same per-ubatch metadata. The
        # forward contexts are built with cudagraph mode NONE in both cases.
        ubatch_metadata = self._make_ubatch_metadata(
            ubatch_slices=ubatch_slices,
            attn_metadata=attn_metadata,
            slot_mapping=slot_mapping,
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
            compute_stream=compute_stream,
            dp_metadata=ubatch_dp_metadata,
            batch_descriptor=batch_descriptor,
            cudagraph_runtime_mode=CUDAGraphMode.NONE,
        )
        with self.sm_control:
            if full_graph_mode:
                return self._capture_ubatches(ubatch_metadata, self.runnable)
            return self._run_ubatches(ubatch_metadata, self.runnable)