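# NOTE(review): the following lines are residue from a web repository viewer,
# not part of the module; they are kept as comments so the file parses.
# "vllm/vscode:/vscode.git/clone" did not exist on "41183c1fe09c60cd77d683e64895c08f0d84b693"
# gpu_ubatch_wrapper.py 17.5 KB
# Newer Older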
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import threading
5
from collections.abc import Callable
6
from dataclasses import dataclass
7
from typing import Any
8
9
10

import torch

11
import vllm.envs as envs
12
13
from vllm.compilation.cuda_graph import CUDAGraphWrapper
from vllm.config import CUDAGraphMode, VllmConfig
14
from vllm.distributed import get_ep_group
15
16
from vllm.distributed.device_communicators.pynccl_allocator import set_graph_pool_id
from vllm.forward_context import (
17
    DPMetadata,
18
19
20
21
    create_forward_context,
    get_forward_context,
    override_forward_context,
)
22
23
24
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
25
from vllm.utils import has_deep_gemm
26
27
28
29
30
from vllm.v1.worker.ubatching import UBatchContext, make_ubatch_contexts

logger = init_logger(__name__)


31
@dataclass
class UbatchMetadata:
    """Inputs and coordination context for running one micro-batch (ubatch)."""

    # Context the ubatch thread enters; it carries the ubatch's id, its
    # compute/comm streams and the CPU wait event used to hand off control.
    context: UBatchContext
    # This ubatch's slice of the full-batch input_ids.
    input_ids: torch.Tensor
    # This ubatch's slice of positions (sliced along the last dimension when
    # positions is 2-D, i.e. mrope).
    positions: torch.Tensor
    # This ubatch's slice of inputs_embeds, or None when not provided.
    inputs_embeds: torch.Tensor | None
    # This ubatch's slice of intermediate_tensors, or None when not provided.
    intermediate_tensors: IntermediateTensors | None
    # Number of tokens in this ubatch (token_slice.stop - token_slice.start).
    num_tokens: int


41
@dataclass
class CUDAGraphMetaData:
    """A captured ubatched CUDA graph together with its capture-time state."""

    # The captured graph; replayed for later calls with the same token count.
    cudagraph: torch.cuda.CUDAGraph
    # Metadata for every ubatch that ran during capture (one entry per
    # ubatch). Presumably retained so the tensors referenced by the graph stay
    # alive — NOTE(review): confirm.
    ubatch_metadata: list[UbatchMetadata]
    # Output tensor produced inside the graph capture; the replay path returns
    # this object directly after calling replay().
    outputs: Any | None = None
46
47


48
class SMControlContextManager:
    """
    Temporarily split the device's SMs between communication and compute.

    On entry the communication setter receives ``comm_sms`` and the compute
    setter receives ``total_sms - comm_sms``. On exit (also when the body
    raises) both setters are reset to ``total_sms``.
    """

    def __init__(
        self,
        comm_sms: int,
        set_comm_sms: Callable[[int], None],
        set_compute_sms: Callable[[int], None],
    ):
        """
        Context manager for controlling SM (Streaming Multiprocessor)
        allocation. Upon entering the context, it sets the number of SMs
        allocated for communication and computation to comm_sms and
        total_sms - comm_sms respectively. Upon exiting, it restores the
        allocation to use all available SMs (i.e. total_sms).

        Args:
            comm_sms (int): The number of SMs to allocate for communication.
                (The remainder will be used for computation.)
            set_comm_sms (Callable[[int], None]):
                A function that sets the number of SMs for communication.
            set_compute_sms (Callable[[int], None]):
                A function that sets the number of SMs for computation.

        Raises:
            AssertionError: If the current platform is not CUDA, or if
                ``comm_sms`` is not strictly smaller than the current device's
                SM count.
        """
        assert current_platform.is_cuda(), (
            "SM control is currently only supported on CUDA"
        )

        props = torch.cuda.get_device_properties(torch.cuda.current_device())
        total_sms = props.multi_processor_count

        # At least one SM must remain for compute.
        assert comm_sms < total_sms
        self.total_sms = total_sms
        self.compute_sms = total_sms - comm_sms
        self.comm_sms = comm_sms
        self.set_comm_sms = set_comm_sms
        self.set_compute_sms = set_compute_sms

    def __enter__(self):
        """Apply the comm/compute SM split and return this manager."""
        self.set_comm_sms(self.comm_sms)
        self.set_compute_sms(self.compute_sms)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Restore full-device allocation for both sides. Returning None lets
        # any exception raised in the body propagate.
        self.set_comm_sms(self.total_sms)
        self.set_compute_sms(self.total_sms)


94
class UBatchWrapper:
    """
    Run a model as two micro-batches ("ubatches") on separate threads.

    When the forward context carries ``ubatch_slices``, the model inputs are
    sliced per ubatch and each slice runs on its own Python thread, with the
    threads coordinated through ``UBatchContext`` objects and a shared
    three-party barrier. In FULL CUDA graph mode the two-thread run is
    captured once per token count and replayed on later calls. Without
    ubatch slices the call is forwarded to the wrapped runnable, or to a
    ``CUDAGraphWrapper`` around it.
    """

    def __init__(
        self,
        runnable: Callable,
        vllm_config: VllmConfig,
        runtime_mode: CUDAGraphMode,
        device: torch.cuda.device,
    ):
        """
        Args:
            runnable: The model callable to wrap.
            vllm_config: Global vLLM configuration.
            runtime_mode: CUDA graph mode. Unless it is ``CUDAGraphMode.NONE``
                a ``CUDAGraphWrapper`` is built for the non-ubatched path and
                the global graph pool is used for ubatched captures.
            device: Device each ubatch thread binds to during graph capture.
        """
        self.runnable = runnable
        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config
        # Passed as comm_stream to every ubatch context.
        self.comm_stream = torch.cuda.Stream(device=device)
        # Two ubatch threads plus the main thread
        self.ready_barrier = threading.Barrier(3)

        # Captured ubatched graphs, keyed by total token count.
        self.cudagraphs: dict[int, CUDAGraphMetaData] = {}

        self.cudagraph_wrapper = None
        self.graph_pool = None
        if runtime_mode is not CUDAGraphMode.NONE:
            self.cudagraph_wrapper = CUDAGraphWrapper(
                runnable, vllm_config, runtime_mode=runtime_mode
            )
            self.graph_pool = current_platform.get_global_graph_pool()

        self.sm_control = self._create_sm_control_context(vllm_config)
        self.device = device

    @staticmethod
    def _create_sm_control_context(vllm_config: VllmConfig):
        """
        Build the SM-partitioning context manager used around ubatched runs.

        The communication budget starts at ``VLLM_DBO_COMM_SMS``. With expert
        parallelism it is capped by the all2all manager's ``max_sms_used()``
        (when that reports a value) and, if still positive, applied through
        ``all2all_manager.set_num_sms``. The compute budget is applied through
        DeepGEMM's ``set_num_sms`` when DeepGEMM is available and the comm
        budget is positive. Setters that do not apply are no-ops.
        """
        comm_sms = envs.VLLM_DBO_COMM_SMS

        set_comm_sms = lambda sms: None
        if vllm_config.parallel_config.enable_expert_parallel:
            # Currently only DeepEP high-throughput supports SM control, so
            # this only affects that case.
            all2all_manager = get_ep_group().device_communicator.all2all_manager

            max_sms_used = all2all_manager.max_sms_used()
            if max_sms_used is not None:
                comm_sms = min(comm_sms, max_sms_used)

            if comm_sms > 0:
                set_comm_sms = lambda sms: all2all_manager.set_num_sms(sms)

        # TODO(lucas): support other kernels besides DeepGEMM
        set_compute_sms = lambda sms: None
        if has_deep_gemm() and comm_sms > 0:
            import deep_gemm as dg

            set_compute_sms = lambda sms: dg.set_num_sms(sms)

        return SMControlContextManager(
            comm_sms=comm_sms,
            set_comm_sms=set_comm_sms,
            set_compute_sms=set_compute_sms,
        )

    def __getattr__(self, key: str):
        # allow accessing the attributes of the runnable.
        # __getattr__ only runs when normal lookup fails. If "runnable" itself
        # is missing (e.g. on an instance whose __init__ has not run), reading
        # self.runnable below would re-enter this method forever, so fail fast.
        if key == "runnable":
            raise AttributeError(key)
        runnable = self.runnable
        if hasattr(runnable, key):
            return getattr(runnable, key)
        raise AttributeError(
            f"Attribute {key} not exists in the runnable of "
            f"cudagraph wrapper: {self.runnable}"
        )

    def unwrap(self) -> Callable:
        # in case we need to access the original runnable.
        return self.runnable

    def _capture_ubatches(self, ubatch_metadata, model) -> torch.Tensor:
        """
        Capture a cudagraph for a microbatched run.

        The logic here is somewhat complicated because we need to make sure
        that each of the ubatch threads initializes its CUDA context before we
        start the graph capture.

        The flow is as follows:
        1. The main thread starts up each ubatch thread. Each thread
        initializes its CUDA context (torch.cuda.current_blas_handle())
        before going to sleep upon entering the ubatch_context.

        2. The main thread starts the graph capture and wakes up the first
        ubatch thread.

        3. Each ubatch thread runs the model to completion and returns its
        output tensors to the main thread.

        4. The main thread stores the captured cudagraph along with its
        metadata and returns the captured output.
        """

        @torch.inference_mode()
        def _capture_ubatch_thread(results, ubatch_metadata):
            torch.cuda.set_device(self.device)
            ubatch_context = ubatch_metadata.context
            # Touch the BLAS handle on both streams so this thread's CUDA
            # state exists before capture starts.
            with torch.cuda.stream(ubatch_context.compute_stream):
                _ = torch.cuda.current_blas_handle()
            with torch.cuda.stream(ubatch_context.comm_stream):
                _ = torch.cuda.current_blas_handle()
            with ubatch_context:
                model_output = model(
                    input_ids=ubatch_metadata.input_ids,
                    positions=ubatch_metadata.positions,
                    intermediate_tensors=ubatch_metadata.intermediate_tensors,
                    inputs_embeds=ubatch_metadata.inputs_embeds,
                )

            results.append((ubatch_metadata.context.id, model_output))

        results: list[tuple[int, torch.Tensor]] = []
        compute_stream = ubatch_metadata[0].context.compute_stream
        num_tokens = ubatch_metadata[0].num_tokens + ubatch_metadata[1].num_tokens

        # Ubatches will manually manage the forward context, so we override
        # it to None here so we can have it restored correctly later
        with override_forward_context(None):
            ubatch_threads = []
            for metadata in ubatch_metadata:
                thread = threading.Thread(
                    target=_capture_ubatch_thread,
                    args=(results, metadata),
                )
                ubatch_threads.append(thread)
                thread.start()
            self.ready_barrier.wait()  # Wait for both threads to be ready

            # Capture the cudagraph
            cudagraph_metadata = CUDAGraphMetaData(
                cudagraph=torch.cuda.CUDAGraph(),
                ubatch_metadata=ubatch_metadata,
            )
            if self.graph_pool is not None:
                set_graph_pool_id(self.graph_pool)
            else:
                set_graph_pool_id(current_platform.graph_pool_handle())
            with torch.cuda.graph(
                cudagraph_metadata.cudagraph,
                stream=compute_stream,
                pool=self.graph_pool,
            ):
                # Wake the first ubatch thread; both threads then run inside
                # the capture.
                ubatch_metadata[0].context.cpu_wait_event.set()
                for thread in ubatch_threads:
                    thread.join()
                # Order outputs by ubatch id (key only, so tensors are never
                # compared). The concatenation is captured as well, so the
                # replay path can return `outputs` directly.
                sorted_results = [
                    output for _, output in sorted(results, key=lambda r: r[0])
                ]
                result = torch.cat(sorted_results, dim=0)
                cudagraph_metadata.outputs = result
            self.cudagraphs[num_tokens] = cudagraph_metadata
        return cudagraph_metadata.outputs

    def _run_ubatches(self, ubatch_metadata, model) -> torch.Tensor:
        """
        Run the ubatches eagerly (no graph capture) on their own threads.

        Returns the ubatch outputs concatenated along dim 0 in ubatch-id
        order.
        """

        @torch.inference_mode()
        def _ubatch_thread(results, model, ubatch_metadata):
            with ubatch_metadata.context:
                model_output = model(
                    input_ids=ubatch_metadata.input_ids,
                    positions=ubatch_metadata.positions,
                    intermediate_tensors=ubatch_metadata.intermediate_tensors,
                    inputs_embeds=ubatch_metadata.inputs_embeds,
                )
            results.append((ubatch_metadata.context.id, model_output))

        results: list[tuple[int, torch.Tensor]] = []

        # Ubatch threads will manually manage the forward context, so we
        # override it to None here so we can have it restored correctly
        # after both threads have finished
        with override_forward_context(None):
            ubatch_threads = []
            for metadata in ubatch_metadata:
                thread = threading.Thread(
                    target=_ubatch_thread,
                    args=(results, model, metadata),
                )
                ubatch_threads.append(thread)
                thread.start()
            self.ready_barrier.wait()  # Wait for both threads to be ready
            ubatch_metadata[0].context.cpu_wait_event.set()
            for thread in ubatch_threads:
                thread.join()
        # Order outputs by ubatch id (key only, so tensors are never compared).
        sorted_results = [output for _, output in sorted(results, key=lambda r: r[0])]
        result = torch.cat(sorted_results, dim=0)
        return result

    def _make_ubatch_metadata(
        self,
        ubatch_slices,
        attn_metadata,
        input_ids,
        positions,
        inputs_embeds,
        intermediate_tensors,
        compute_stream,
        dp_metadata,
        batch_descriptor,
        cudagraph_runtime_mode,
    ) -> list[UbatchMetadata]:
        """
        Build one UbatchMetadata per slice in ``ubatch_slices``.

        Each ubatch gets its own forward context (with ``attn_metadata[i]``
        when attention metadata is given) and a ubatch context sharing this
        wrapper's comm stream and ready barrier. The model inputs are sliced
        by the slice's ``token_slice``.
        """
        # Create one forward context per ubatch
        forward_contexts = [
            create_forward_context(
                attn_metadata[i] if attn_metadata is not None else None,
                self.vllm_config,
                dp_metadata=dp_metadata,
                batch_descriptor=batch_descriptor,
                cudagraph_runtime_mode=cudagraph_runtime_mode,
            )
            for i in range(len(ubatch_slices))
        ]

        ubatch_ctxs = make_ubatch_contexts(
            num_micro_batches=len(ubatch_slices),
            comm_stream=self.comm_stream,
            compute_stream=compute_stream,
            forward_contexts=forward_contexts,
            ready_barrier=self.ready_barrier,
        )

        ubatch_metadata: list[UbatchMetadata] = []
        for i, ubatch_slice in enumerate(ubatch_slices):
            (
                sliced_input_ids,
                sliced_positions,
                sliced_inputs_embeds,
                sliced_intermediate_tensors,
            ) = self._slice_model_inputs(
                ubatch_slice.token_slice,
                input_ids,
                positions,
                inputs_embeds,
                intermediate_tensors,
            )
            ubatch_metadata.append(
                UbatchMetadata(
                    context=ubatch_ctxs[i],
                    input_ids=sliced_input_ids,
                    positions=sliced_positions,
                    inputs_embeds=sliced_inputs_embeds,
                    intermediate_tensors=sliced_intermediate_tensors,
                    num_tokens=ubatch_slice.token_slice.stop
                    - ubatch_slice.token_slice.start,
                )
            )

        return ubatch_metadata

    def _slice_model_inputs(
        self,
        tokens_slice: slice,
        input_ids,
        positions,
        inputs_embeds,
        intermediate_tensors,
    ):
        """
        Slice the model inputs down to ``tokens_slice``.

        ``inputs_embeds`` and ``intermediate_tensors`` may be None and are
        then passed through as None.
        """
        sliced_input_ids = input_ids[tokens_slice]
        # if we are using mrope. Mrope adds an additional dimension to the
        # positions tensor
        if positions.ndim == 2:
            sliced_positions = positions[:, tokens_slice]
        else:
            sliced_positions = positions[tokens_slice]
        # Compare against None explicitly: truth-testing a multi-element
        # tensor raises, and an empty IntermediateTensors must not be dropped.
        sliced_inputs_embeds = (
            inputs_embeds[tokens_slice] if inputs_embeds is not None else None
        )
        sliced_intermediate_tensors = (
            intermediate_tensors[tokens_slice]
            if intermediate_tensors is not None
            else None
        )

        return (
            sliced_input_ids,
            sliced_positions,
            sliced_inputs_embeds,
            sliced_intermediate_tensors,
        )

    def _make_ubatch_dp_metadata(self, num_tokens_per_ubatch: int) -> DPMetadata:
        """
        Build DP metadata for a single ubatch.

        Every DP rank is recorded as holding ``num_tokens_per_ubatch`` tokens.
        """
        dp_size = self.vllm_config.parallel_config.data_parallel_size
        ubatch_num_tokens_across_dp = torch.tensor(
            [num_tokens_per_ubatch] * dp_size, device="cpu", dtype=torch.int32
        )
        return DPMetadata.make(
            self.vllm_config.parallel_config,
            num_tokens_per_ubatch,
            ubatch_num_tokens_across_dp,
        )

    def __call__(self, *args, **kwargs):
        forward_context = get_forward_context()
        batch_descriptor = forward_context.batch_descriptor
        ubatch_slices = forward_context.ubatch_slices
        cudagraph_runtime_mode = forward_context.cudagraph_runtime_mode

        # If there's no ubatching, just run the runnable object
        if ubatch_slices is None:
            # This is to account for the case where ubatching was aborted.
            # When we capture full graphs we only capture one graph per shape,
            # meaning that if we have a ubatched cudagraph for the current
            # num_tokens, we don't have a non-ubatched one. Without this
            # check, the cudagraph wrapper will try to capture a cudagraph
            # for this shape during a normal run.
            if cudagraph_runtime_mode is CUDAGraphMode.FULL:
                assert batch_descriptor is not None
                if batch_descriptor.num_tokens in self.cudagraphs:
                    cudagraph_runtime_mode = CUDAGraphMode.NONE

            if cudagraph_runtime_mode in (CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE):
                return self.runnable(*args, **kwargs)
            else:
                assert self.cudagraph_wrapper is not None
                return self.cudagraph_wrapper(*args, **kwargs)

        attn_metadata = forward_context.attn_metadata
        num_tokens_per_ubatch = (
            ubatch_slices[0].token_slice.stop - ubatch_slices[0].token_slice.start
        )
        # NOTE(review): this graph-cache key assumes two equally sized ubatches
        # (the capture path keys by the sum of both ubatch sizes); confirm the
        # slices are always padded to equal length in FULL mode.
        num_tokens = num_tokens_per_ubatch * 2
        input_ids = kwargs["input_ids"]
        positions = kwargs["positions"]
        intermediate_tensors = kwargs["intermediate_tensors"]
        inputs_embeds = kwargs["inputs_embeds"]
        compute_stream = torch.cuda.current_stream()

        dp_metadata = forward_context.dp_metadata

        # We shouldn't be here unless we are running with multiple DP ranks
        assert dp_metadata is not None

        if (
            num_tokens not in self.cudagraphs
            and cudagraph_runtime_mode is CUDAGraphMode.FULL
        ):
            # Per-ubatch DP metadata is only needed for capture, so it is not
            # rebuilt on every replay.
            ubatch_metadata = self._make_ubatch_metadata(
                ubatch_slices=ubatch_slices,
                attn_metadata=attn_metadata,
                input_ids=input_ids,
                positions=positions,
                intermediate_tensors=intermediate_tensors,
                inputs_embeds=inputs_embeds,
                compute_stream=compute_stream,
                dp_metadata=self._make_ubatch_dp_metadata(num_tokens_per_ubatch),
                batch_descriptor=batch_descriptor,
                cudagraph_runtime_mode=CUDAGraphMode.NONE,
            )
            with self.sm_control:
                return self._capture_ubatches(ubatch_metadata, self.model)
        elif (
            num_tokens in self.cudagraphs
            and cudagraph_runtime_mode is CUDAGraphMode.FULL
        ):
            cudagraph_metadata = self.cudagraphs[num_tokens]
            cudagraph_metadata.cudagraph.replay()
            return cudagraph_metadata.outputs
        else:
            ubatch_metadata = self._make_ubatch_metadata(
                ubatch_slices=ubatch_slices,
                attn_metadata=attn_metadata,
                input_ids=input_ids,
                positions=positions,
                intermediate_tensors=intermediate_tensors,
                inputs_embeds=inputs_embeds,
                compute_stream=compute_stream,
                dp_metadata=dp_metadata,
                batch_descriptor=batch_descriptor,
                cudagraph_runtime_mode=CUDAGraphMode.NONE,
            )
            with self.sm_control:
                return self._run_ubatches(ubatch_metadata, self.model)