gpu_ubatch_wrapper.py 17.5 KB
Newer Older
1
2
3
4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import threading
5
from dataclasses import dataclass
6
7
8
9
from typing import Any, Callable, Optional

import torch

10
import vllm.envs as envs
11
12
from vllm.compilation.cuda_graph import CUDAGraphWrapper
from vllm.config import CUDAGraphMode, VllmConfig
13
from vllm.distributed import get_ep_group
14
15
from vllm.distributed.device_communicators.pynccl_allocator import set_graph_pool_id
from vllm.forward_context import (
16
    DPMetadata,
17
18
19
20
    create_forward_context,
    get_forward_context,
    override_forward_context,
)
21
22
23
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
24
from vllm.utils import has_deep_gemm
25
26
27
28
29
from vllm.v1.worker.ubatching import UBatchContext, make_ubatch_contexts

logger = init_logger(__name__)


@dataclass
class UbatchMetadata:
    """Everything a single ubatch thread needs to run its share of a batch."""

    # Per-ubatch context; the wrapper reads its `id`, `compute_stream`,
    # `comm_stream` and `cpu_wait_event`, and enters it around the model call.
    context: UBatchContext
    # Model inputs already restricted to this ubatch's token slice.
    input_ids: torch.Tensor
    positions: torch.Tensor
    inputs_embeds: Optional[torch.Tensor]
    intermediate_tensors: Optional[IntermediateTensors]
    # Number of tokens in this ubatch (token_slice.stop - token_slice.start).
    num_tokens: int


@dataclass
class CUDAGraphMetaData:
    """A captured micro-batched CUDA graph plus the state tied to it."""

    cudagraph: torch.cuda.CUDAGraph
    # The per-ubatch metadata (one entry per ubatch, in ubatch order) that was
    # used while capturing `cudagraph`; holding it keeps those inputs alive.
    ubatch_metadata: list[UbatchMetadata]
    # Output produced during capture. The wrapper returns this same object
    # after every replay of `cudagraph`.
    outputs: Optional[Any] = None


class SMControlContextManager:
    """Split the GPU's SMs between communication and compute while active.

    Entering the context gives `comm_sms` SMs to communication and the
    remaining `total_sms - comm_sms` to computation through the supplied
    setter callbacks. Exiting hands the full SM count back to both.
    """

    def __init__(
        self,
        comm_sms: int,
        set_comm_sms: Callable[[int], None],
        set_compute_sms: Callable[[int], None],
    ):
        """
        Context manager for controlling SM (Streaming Multiprocessor)
        allocation. Upon entering the context, it sets the number of SMs
        allocated for communication and computation to comm_sms and
        total_sms - comm_sms respectively. Upon exiting, it restores the
        allocation to use all available SMs (i.e. total_sms).

        Args:
            comm_sms (int): The number of SMs to allocate for communication.
                (The remainder will be used for computation.)
            set_comm_sms (Callable[[int], None]):
                A function that sets the number of SMs for communication.
            set_compute_sms (Callable[[int], None]):
                A function that sets the number of SMs for computation.
        """

        assert current_platform.is_cuda(), (
            "SM control is currently only supported on CUDA"
        )

        # total_sms is read from the device that is current at construction.
        props = torch.cuda.get_device_properties(torch.cuda.current_device())
        total_sms = props.multi_processor_count

        # Computation must keep at least one SM.
        assert comm_sms < total_sms
        self.total_sms = total_sms
        self.compute_sms = total_sms - comm_sms
        self.comm_sms = comm_sms
        self.set_comm_sms = set_comm_sms
        self.set_compute_sms = set_compute_sms

    def __enter__(self):
        """Apply the communication/compute split and return this manager."""
        self.set_comm_sms(self.comm_sms)
        self.set_compute_sms(self.compute_sms)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Give every SM back to both sides; exceptions are not suppressed."""
        self.set_comm_sms(self.total_sms)
        self.set_compute_sms(self.total_sms)


class UBatchWrapper:
    """Execute a model runnable as micro-batches ("ubatches") on threads.

    When the active forward context provides `ubatch_slices`, the model
    inputs are split along those token slices. Each ubatch then runs the model
    on its own Python thread inside a `UBatchContext`, and the per-ubatch
    outputs are concatenated along dim 0 in ubatch-id order.

    Under `CUDAGraphMode.FULL` the complete micro-batched run is captured
    once per total token count (stored in `self.cudagraphs`) and replayed on
    later calls with the same count. Calls without `ubatch_slices` are
    forwarded to the plain runnable or to a regular `CUDAGraphWrapper`.

    Attribute lookups that miss on the wrapper fall through to the wrapped
    runnable (see `__getattr__`).
    """

    def __init__(
        self,
        runnable: Callable,
        vllm_config: VllmConfig,
        runtime_mode: CUDAGraphMode,
        device: torch.cuda.device,
    ):
        """Set up streams, thread synchronization and optional graph state.

        Args:
            runnable: The wrapped callable. Non-ubatched calls invoke it
                directly. Ubatched calls run its `model` attribute (resolved
                through `__getattr__`) on each ubatch thread.
            vllm_config: Global vLLM configuration.
            runtime_mode: CUDA graph mode of this wrapper. Any mode other than
                `CUDAGraphMode.NONE` creates a fallback `CUDAGraphWrapper` and
                uses the platform's global graph pool.
            device: Device for the communication stream. Capture threads
                select this device before running.
        """
        self.runnable = runnable
        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config
        # Passed to every ubatch context as its communication stream.
        self.comm_stream = torch.cuda.Stream(device=device)
        # Two ubatch threads plus the main thread
        self.ready_barrier = threading.Barrier(3)

        # Captured ubatched graphs keyed by the total number of tokens across
        # all ubatches (see _capture_ubatches and __call__).
        self.cudagraphs: dict[int, CUDAGraphMetaData] = {}

        self.cudagraph_wrapper = None
        self.graph_pool = None
        if runtime_mode is not CUDAGraphMode.NONE:
            self.cudagraph_wrapper = CUDAGraphWrapper(
                runnable, vllm_config, runtime_mode=runtime_mode
            )
            self.graph_pool = current_platform.get_global_graph_pool()

        self.sm_control = self._create_sm_control_context(vllm_config)
        self.device = device

    @staticmethod
    def _create_sm_control_context(vllm_config: VllmConfig):
        """Build the SM-partitioning context used around ubatched runs.

        The communication budget starts at `envs.VLLM_DBO_COMM_SMS`. With
        expert parallelism enabled it is capped by the all2all manager's
        `max_sms_used()` (when that reports a limit). Setters that do not
        apply are left as no-ops.
        """
        comm_sms = envs.VLLM_DBO_COMM_SMS

        set_comm_sms = lambda sms: None
        if vllm_config.parallel_config.enable_expert_parallel:
            # Currently only DeepEP highthroughput supports SM control so this
            # only affects that case.
            all2all_manager = get_ep_group().device_communicator.all2all_manager

            max_sms_used = all2all_manager.max_sms_used()
            if max_sms_used is not None:
                comm_sms = min(comm_sms, max_sms_used)

            if comm_sms > 0:
                set_comm_sms = lambda sms: all2all_manager.set_num_sms(sms)

        # TODO(lucas): support other kernels besides DeepGEMM
        set_compute_sms = lambda sms: None
        if has_deep_gemm() and comm_sms > 0:
            import deep_gemm as dg

            set_compute_sms = lambda sms: dg.set_num_sms(sms)

        return SMControlContextManager(
            comm_sms=comm_sms,
            set_comm_sms=set_comm_sms,
            set_compute_sms=set_compute_sms,
        )

    def __getattr__(self, key: str):
        """Forward attribute lookups that miss on the wrapper to the runnable.

        Python only calls this after normal lookup fails. `runnable` is read
        straight from the instance `__dict__`, so a partially constructed
        instance (e.g. one produced by `copy.copy` or unpickling before its
        state is restored) raises `AttributeError` instead of re-entering
        `__getattr__` for `runnable` without bound.
        """
        try:
            runnable = self.__dict__["runnable"]
        except KeyError:
            raise AttributeError(key) from None
        # allow accessing the attributes of the runnable.
        if hasattr(runnable, key):
            return getattr(runnable, key)
        raise AttributeError(
            f"Attribute {key} not exists in the runnable of "
            f"cudagraph wrapper: {runnable}"
        )

    def unwrap(self) -> Callable:
        """Return the wrapped runnable."""
        # in case we need to access the original runnable.
        return self.runnable

    @staticmethod
    def _merge_ubatch_outputs(
        results: list[tuple[int, torch.Tensor]],
    ) -> torch.Tensor:
        """Concatenate `(ubatch_id, output)` pairs along dim 0 by ubatch id."""
        # Sort on the id alone so the output tensors are never compared.
        ordered = sorted(results, key=lambda item: item[0])
        return torch.cat([output for _, output in ordered], dim=0)

    def _capture_ubatches(self, ubatch_metadata, model) -> torch.Tensor:
        """
        Capture a cudagraph for a microbatched run.

        The logic here is somewhat complicated because we need to make sure that
        each of the ubatch threads initializes its cuda context before we start
        the graph capture.

        The flow is as follows:
        1. The main thread starts up each ubatch thread. Each thread will
        initialize its cuda context (torch.cuda.current_blas_handle())
        before going to sleep upon entering the ubatch_context.

        2. The main thread starts the graph capture and wakes up the first
        ubatch thread.

        3. Each ubatch thread runs the model to completion and returns the
        completed output tensors back to the main thread.

        4. The main thread stores the captured cudagraph along with its metadata
        and returns.

        Args:
            ubatch_metadata: Per-ubatch inputs and contexts, in ubatch order.
            model: Callable executed by every ubatch thread.

        Returns:
            The concatenated output produced during capture. `__call__`
            returns this same object after every later replay of the graph.
        """

        @torch.inference_mode()
        def _capture_ubatch_thread(results, ubatch_metadata):
            torch.cuda.set_device(self.device)
            ubatch_context = ubatch_metadata.context
            # Step 1: create this thread's BLAS handle on both streams before
            # the capture begins.
            with torch.cuda.stream(ubatch_context.compute_stream):
                _ = torch.cuda.current_blas_handle()
            with torch.cuda.stream(ubatch_context.comm_stream):
                _ = torch.cuda.current_blas_handle()
            with ubatch_context:
                model_output = model(
                    input_ids=ubatch_metadata.input_ids,
                    positions=ubatch_metadata.positions,
                    intermediate_tensors=ubatch_metadata.intermediate_tensors,
                    inputs_embeds=ubatch_metadata.inputs_embeds,
                )

            results.append((ubatch_metadata.context.id, model_output))

        results: list[tuple[int, torch.Tensor]] = []
        compute_stream = ubatch_metadata[0].context.compute_stream
        # Key the graph by the total token count across all ubatches; __call__
        # computes its lookup key the same way.
        num_tokens = sum(metadata.num_tokens for metadata in ubatch_metadata)

        # Ubatches will manually manage the forward context, so we override
        # it to None here so we can have it restored correctly later
        with override_forward_context(None):
            ubatch_threads = []
            for metadata in ubatch_metadata:
                thread = threading.Thread(
                    target=_capture_ubatch_thread,
                    args=(results, metadata),
                )
                ubatch_threads.append(thread)
                thread.start()
            self.ready_barrier.wait()  # Wait for both threads to be ready

            # Capture the cudagraph
            cudagraph_metadata = CUDAGraphMetaData(
                cudagraph=torch.cuda.CUDAGraph(),
                ubatch_metadata=ubatch_metadata,
            )
            if self.graph_pool is not None:
                set_graph_pool_id(self.graph_pool)
            else:
                set_graph_pool_id(current_platform.graph_pool_handle())
            with torch.cuda.graph(
                cudagraph_metadata.cudagraph,
                stream=compute_stream,
                pool=self.graph_pool,
            ):
                # Step 2: wake the first ubatch thread, then wait for all.
                ubatch_metadata[0].context.cpu_wait_event.set()
                for thread in ubatch_threads:
                    thread.join()
                # The concatenation happens inside the capture so that it is
                # part of the recorded graph.
                cudagraph_metadata.outputs = self._merge_ubatch_outputs(results)
            self.cudagraphs[num_tokens] = cudagraph_metadata
        return cudagraph_metadata.outputs

    def _run_ubatches(self, ubatch_metadata, model) -> torch.Tensor:
        """Run every ubatch eagerly on its own thread and merge the outputs.

        Args:
            ubatch_metadata: Per-ubatch inputs and contexts, in ubatch order.
            model: Callable executed by every ubatch thread.

        Returns:
            Per-ubatch outputs concatenated along dim 0 in ubatch-id order.
        """

        @torch.inference_mode()
        def _ubatch_thread(results, model, ubatch_metadata):
            with ubatch_metadata.context:
                model_output = model(
                    input_ids=ubatch_metadata.input_ids,
                    positions=ubatch_metadata.positions,
                    intermediate_tensors=ubatch_metadata.intermediate_tensors,
                    inputs_embeds=ubatch_metadata.inputs_embeds,
                )
            results.append((ubatch_metadata.context.id, model_output))

        results: list[tuple[int, torch.Tensor]] = []

        # Ubatch threads will manually manage the forward context, so we
        # override it to None here so we can have it restored correctly
        # after both threads have finished
        with override_forward_context(None):
            ubatch_threads = []
            for metadata in ubatch_metadata:
                thread = threading.Thread(
                    target=_ubatch_thread,
                    args=(results, model, metadata),
                )
                ubatch_threads.append(thread)
                thread.start()
            self.ready_barrier.wait()  # Wait for both threads to be ready
            # Wake the first ubatch thread, then wait for all of them.
            ubatch_metadata[0].context.cpu_wait_event.set()
            for thread in ubatch_threads:
                thread.join()
        return self._merge_ubatch_outputs(results)

    def _make_ubatch_metadata(
        self,
        ubatch_slices,
        attn_metadata,
        input_ids,
        positions,
        inputs_embeds,
        intermediate_tensors,
        compute_stream,
        dp_metadata,
        batch_descriptor,
        cudagraph_runtime_mode,
    ) -> list[UbatchMetadata]:
        """Build the forward context, ubatch context and inputs per ubatch.

        Args:
            ubatch_slices: One entry per ubatch; each exposes `token_slice`.
            attn_metadata: Indexed by ubatch position, or None.
            input_ids, positions, inputs_embeds, intermediate_tensors: Full
                batch inputs, sliced per ubatch by `_slice_model_inputs`.
            compute_stream: Compute stream handed to every ubatch context.
            dp_metadata, batch_descriptor, cudagraph_runtime_mode: Forwarded
                to each ubatch's `create_forward_context` call.

        Returns:
            One `UbatchMetadata` per entry of `ubatch_slices`, in order.
        """
        # Create one forward context per ubatch
        forward_contexts = [
            create_forward_context(
                attn_metadata[i] if attn_metadata is not None else None,
                self.vllm_config,
                dp_metadata=dp_metadata,
                batch_descriptor=batch_descriptor,
                cudagraph_runtime_mode=cudagraph_runtime_mode,
            )
            for i in range(len(ubatch_slices))
        ]

        ubatch_ctxs = make_ubatch_contexts(
            num_micro_batches=len(ubatch_slices),
            comm_stream=self.comm_stream,
            compute_stream=compute_stream,
            forward_contexts=forward_contexts,
            ready_barrier=self.ready_barrier,
        )

        ubatch_metadata: list[UbatchMetadata] = []
        for i, ubatch_slice in enumerate(ubatch_slices):
            token_slice = ubatch_slice.token_slice
            (
                sliced_input_ids,
                sliced_positions,
                sliced_inputs_embeds,
                sliced_intermediate_tensors,
            ) = self._slice_model_inputs(
                token_slice,
                input_ids,
                positions,
                inputs_embeds,
                intermediate_tensors,
            )
            ubatch_metadata.append(
                UbatchMetadata(
                    context=ubatch_ctxs[i],
                    input_ids=sliced_input_ids,
                    positions=sliced_positions,
                    inputs_embeds=sliced_inputs_embeds,
                    intermediate_tensors=sliced_intermediate_tensors,
                    num_tokens=token_slice.stop - token_slice.start,
                )
            )

        return ubatch_metadata

    def _slice_model_inputs(
        self,
        tokens_slice: slice,
        input_ids,
        positions,
        inputs_embeds,
        intermediate_tensors,
    ):
        """Restrict the model inputs to the tokens selected by `tokens_slice`.

        Returns:
            `(input_ids, positions, inputs_embeds, intermediate_tensors)`
            sliced to `tokens_slice`. Optional inputs that are None stay None.
        """
        sliced_input_ids = input_ids[tokens_slice]
        # if we are using mrope. Mrope adds an additional dimension to the
        # positions tensor
        if positions.ndim == 2:
            sliced_positions = positions[:, tokens_slice]
        else:
            sliced_positions = positions[tokens_slice]
        # Compare with None explicitly: truth-testing a multi-element tensor
        # raises "Boolean value of Tensor with more than one value is
        # ambiguous".
        sliced_inputs_embeds = (
            inputs_embeds[tokens_slice] if inputs_embeds is not None else None
        )
        sliced_intermediate_tensors = (
            intermediate_tensors[tokens_slice]
            if intermediate_tensors is not None
            else None
        )

        return (
            sliced_input_ids,
            sliced_positions,
            sliced_inputs_embeds,
            sliced_intermediate_tensors,
        )

    def __call__(self, *args, **kwargs):
        """Run the wrapped model, micro-batched when the forward context asks.

        Dispatch on the active forward context:

        * no `ubatch_slices`: call the runnable directly (`NONE`/`PIECEWISE`,
          or `FULL` when a ubatched graph already exists for this token
          count), otherwise go through `self.cudagraph_wrapper`;
        * `FULL` with no graph for the total token count: capture one;
        * `FULL` with a graph for the total token count: replay it;
        * anything else: run the ubatches eagerly.

        The ubatched paths read `input_ids`, `positions`,
        `intermediate_tensors` and `inputs_embeds` from `kwargs`.
        """
        forward_context = get_forward_context()
        batch_descriptor = forward_context.batch_descriptor
        ubatch_slices = forward_context.ubatch_slices
        cudagraph_runtime_mode = forward_context.cudagraph_runtime_mode

        # If there's no ubatching, just run the runnable object
        if ubatch_slices is None:
            # This is to account for the case where ubatching was aborted.
            # When we capture full graphs we only capture one graph per shape,
            # meaning that if we have a ubatched cudagraph for the current
            # num_tokens, we don't have a non-ubatched one. Without this
            # check, the cudagraph wrapper will try to capture a cudagraph
            # for this shape during a normal run.
            if cudagraph_runtime_mode is CUDAGraphMode.FULL:
                assert batch_descriptor is not None
                if batch_descriptor.num_tokens in self.cudagraphs:
                    cudagraph_runtime_mode = CUDAGraphMode.NONE

            if cudagraph_runtime_mode in (CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE):
                return self.runnable(*args, **kwargs)
            else:
                assert self.cudagraph_wrapper is not None
                return self.cudagraph_wrapper(*args, **kwargs)

        attn_metadata = forward_context.attn_metadata
        # Total tokens across all ubatches: the key used for self.cudagraphs,
        # matching how _capture_ubatches stores captured graphs.
        num_tokens = sum(
            ubatch_slice.token_slice.stop - ubatch_slice.token_slice.start
            for ubatch_slice in ubatch_slices
        )
        input_ids = kwargs["input_ids"]
        positions = kwargs["positions"]
        intermediate_tensors = kwargs["intermediate_tensors"]
        inputs_embeds = kwargs["inputs_embeds"]
        compute_stream = torch.cuda.current_stream()

        dp_metadata = forward_context.dp_metadata

        # We shouldn't be here unless we are running with multiple DP ranks
        assert dp_metadata is not None
        # DP metadata describing one ubatch: every rank is assumed to
        # contribute the first ubatch's token count.
        num_tokens_per_ubatch = (
            ubatch_slices[0].token_slice.stop - ubatch_slices[0].token_slice.start
        )
        dp_size = self.vllm_config.parallel_config.data_parallel_size
        ubatch_num_tokens_across_dp = torch.tensor(
            [num_tokens_per_ubatch] * dp_size, device="cpu", dtype=torch.int32
        )
        ubatch_dp_metadata = DPMetadata.make(
            self.vllm_config.parallel_config,
            num_tokens_per_ubatch,
            ubatch_num_tokens_across_dp,
        )

        if (
            num_tokens not in self.cudagraphs
            and cudagraph_runtime_mode is CUDAGraphMode.FULL
        ):
            ubatch_metadata = self._make_ubatch_metadata(
                ubatch_slices=ubatch_slices,
                attn_metadata=attn_metadata,
                input_ids=input_ids,
                positions=positions,
                intermediate_tensors=intermediate_tensors,
                inputs_embeds=inputs_embeds,
                compute_stream=compute_stream,
                dp_metadata=ubatch_dp_metadata,
                batch_descriptor=batch_descriptor,
                cudagraph_runtime_mode=CUDAGraphMode.NONE,
            )
            with self.sm_control:
                return self._capture_ubatches(ubatch_metadata, self.model)
        elif (
            num_tokens in self.cudagraphs
            and cudagraph_runtime_mode is CUDAGraphMode.FULL
        ):
            cudagraph_metadata = self.cudagraphs[num_tokens]
            cudagraph_metadata.cudagraph.replay()
            return cudagraph_metadata.outputs
        else:
            # NOTE(review): the eager path passes the batch-level dp_metadata
            # while the capture path uses ubatch_dp_metadata; confirm this
            # asymmetry is intended.
            ubatch_metadata = self._make_ubatch_metadata(
                ubatch_slices=ubatch_slices,
                attn_metadata=attn_metadata,
                input_ids=input_ids,
                positions=positions,
                intermediate_tensors=intermediate_tensors,
                inputs_embeds=inputs_embeds,
                compute_stream=compute_stream,
                dp_metadata=dp_metadata,
                batch_descriptor=batch_descriptor,
                cudagraph_runtime_mode=CUDAGraphMode.NONE,
            )
            with self.sm_control:
                return self._run_ubatches(ubatch_metadata, self.model)