# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""All2All communication managers used for MoE expert parallelism."""

from typing import Any

import torch
import torch.distributed as dist

import vllm.envs as envs
from vllm.distributed import get_dp_group, get_ep_group
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.utils.flashinfer import has_flashinfer_all2all
from vllm.utils.import_utils import has_deep_ep, has_mori, has_pplx

from .base_device_communicator import All2AllManagerBase, Cache

if has_flashinfer_all2all():
    from flashinfer.comm import Mapping  # type: ignore[import-not-found]
    from flashinfer.comm.mnnvl import MnnvlConfig  # type: ignore[import-not-found]
    from flashinfer.comm.trtllm_alltoall import (
        MnnvlMoe,  # type: ignore[import-not-found]
    )

logger = init_logger(__name__)


class NaiveAll2AllManager(All2AllManagerBase):
    """
    A naive implementation of all2all communication.

    Dispatch broadcasts every rank's slice over the EP group and combine
    uses an all-reduce, which is not efficient at all. The main purpose is
    for testing and debugging.
    """

    def __init__(self, cpu_group):
        super().__init__(cpu_group)

    @staticmethod
    def _token_range(cu_tokens_across_sp_cpu: torch.Tensor, rank: int):
        """Return the ``(start, end)`` row range owned by ``rank``.

        ``cu_tokens_across_sp_cpu`` is treated as a cumulative token count:
        rank 0 owns rows ``[0, cu[0])`` and rank ``r`` owns
        ``[cu[r - 1], cu[r])``.
        """
        start = 0 if rank == 0 else cu_tokens_across_sp_cpu[rank - 1]
        end = cu_tokens_across_sp_cpu[rank]
        return start, end

    def naive_multicast(
        self,
        x: torch.Tensor,
        cu_tokens_across_sp_cpu: torch.Tensor,
        is_sequence_parallel: bool,
    ) -> torch.Tensor:
        """Assemble every rank's rows of ``x`` into one buffer on all ranks.

        This rank copies its local 2-D ``x`` into its own slice of a buffer
        sized ``cu_tokens_across_sp_cpu[-1]`` rows. Each slice is then
        broadcast over the EP group from the rank that owns it.
        """
        assert len(x.shape) == 2
        buffer = torch.empty(
            (cu_tokens_across_sp_cpu[-1], x.size(1)), device=x.device, dtype=x.dtype
        )

        rank = self.rank if is_sequence_parallel else self.dp_rank
        world_size = self.world_size if is_sequence_parallel else self.dp_world_size

        start, end = self._token_range(cu_tokens_across_sp_cpu, rank)
        buffer[start:end, :].copy_(x)
        for src in range(world_size):
            start, end = self._token_range(cu_tokens_across_sp_cpu, src)
            get_ep_group().broadcast(buffer[start:end, :], src)

        return buffer

    def dispatch(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        is_sequence_parallel: bool = False,
        extra_tensors: list[torch.Tensor] | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Multicast ``hidden_states`` and ``router_logits`` to all ranks.

        Raises:
            NotImplementedError: if ``extra_tensors`` is provided.
        """
        if extra_tensors is not None:
            raise NotImplementedError(
                "extra_tensors is not supported for NaiveAll2AllManager"
            )
        sp_size = self.tp_group.world_size if is_sequence_parallel else 1
        dp_metadata = get_forward_context().dp_metadata
        assert dp_metadata is not None
        cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)

        hidden_states = self.naive_multicast(
            hidden_states, cu_tokens_across_sp_cpu, is_sequence_parallel
        )
        router_logits = self.naive_multicast(
            router_logits, cu_tokens_across_sp_cpu, is_sequence_parallel
        )
        return hidden_states, router_logits

    def combine(
        self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
    ) -> torch.Tensor:
        """All-reduce ``hidden_states`` over the EP group; keep this rank's rows."""
        ep_rank = self.rank if is_sequence_parallel else self.dp_rank

        dp_metadata = get_forward_context().dp_metadata
        assert dp_metadata is not None
        sp_size = self.tp_group.world_size if is_sequence_parallel else 1
        cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)

        start, end = self._token_range(cu_tokens_across_sp_cpu, ep_rank)
        all_hidden_states = get_ep_group().all_reduce(hidden_states)
        return all_hidden_states[start:end, :]

    def destroy(self):
        pass


class AgRsAll2AllManager(All2AllManagerBase):
    """
    An implementation of all2all communication based on
    all-gather (dispatch) and reduce-scatter (combine).
    """

    def __init__(self, cpu_group):
        super().__init__(cpu_group)

    @staticmethod
    def _chunk_sizes_and_group(is_sequence_parallel: bool):
        """Return ``(sizes, dist_group)`` for a gather/scatter collective.

        ``sizes`` are the per-rank chunk sizes from the forward context's DP
        metadata. The EP group is used under sequence parallelism; otherwise
        the DP group is used.
        """
        dp_metadata = get_forward_context().dp_metadata
        assert dp_metadata is not None
        sizes = dp_metadata.get_chunk_sizes_across_dp_rank()
        assert sizes is not None
        dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
        return sizes, dist_group

    def dispatch(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        is_sequence_parallel: bool = False,
        extra_tensors: list[torch.Tensor] | None = None,
    ) -> (
        tuple[torch.Tensor, torch.Tensor]
        | tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
    ):
        """
        Gather hidden_states and router_logits from all dp ranks.

        Any ``extra_tensors`` are gathered in the same collective and
        returned as a third element (a list), in their original order.
        """
        sizes, dist_group = self._chunk_sizes_and_group(is_sequence_parallel)
        # This rank must contribute exactly the chunk size recorded for it.
        assert sizes[dist_group.rank_in_group] == hidden_states.shape[0]

        tensors_to_gather = [hidden_states, router_logits]
        if extra_tensors is not None:
            tensors_to_gather.extend(extra_tensors)

        gathered_tensors = dist_group.all_gatherv(
            tensors_to_gather,
            dim=0,
            sizes=sizes,
        )

        if extra_tensors is not None:
            return (gathered_tensors[0], gathered_tensors[1], gathered_tensors[2:])
        return gathered_tensors[0], gathered_tensors[1]

    def combine(
        self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
    ) -> torch.Tensor:
        """
        Reduce-scatter hidden_states across all dp ranks.
        """
        sizes, dist_group = self._chunk_sizes_and_group(is_sequence_parallel)
        return dist_group.reduce_scatterv(hidden_states, dim=0, sizes=sizes)

    def destroy(self):
        pass


class PPLXAll2AllManager(All2AllManagerBase):
    """
    All2All communication based on PPLX kernels.

    dispatch/combine are not implemented on this manager; it provides
    cached PPLX ``AllToAll`` handles through ``get_handle``.
    """

    def __init__(self, cpu_group):
        assert has_pplx(), (
            "pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md"
            " to install pplx_kernels."
        )
        super().__init__(cpu_group)

        if self.internode:
            # inter-node communication needs nvshmem,
            # intra-node communication uses p2p mapping directly
            from pplx_kernels.nvshmem import (  # type: ignore[import-not-found]
                nvshmem_alloc_empty_unique_id,
                nvshmem_get_unique_id,
                nvshmem_init,
            )

            logger.debug(
                "Initialize NVSHMEM for pplx_kernels: rank=%d, world size=%d",
                self.rank,
                self.world_size,
            )
            # Rank 0 creates the unique id. Other ranks allocate an empty one
            # that is overwritten by the broadcast from the group's first rank.
            uid = (
                nvshmem_get_unique_id()
                if self.rank == 0
                else nvshmem_alloc_empty_unique_id()
            )
            dist.broadcast(
                uid,
                src=dist.get_process_group_ranks(self.cpu_group)[0],
                group=self.cpu_group,
            )
            logger.debug("PPLX NVSHMEM UID = %s", uid)
            nvshmem_init(uid, self.rank, self.world_size)

        self.handle_cache = Cache()

    def get_handle(self, kwargs):
        """Return the cached PPLX AllToAll handle for ``kwargs`` (get_or_create).

        The internode or intranode constructor is chosen based on
        ``self.internode``.
        """
        import pplx_kernels as pplx  # type: ignore[import-not-found]

        return self.handle_cache.get_or_create(
            kwargs,
            pplx.AllToAll.internode if self.internode else pplx.AllToAll.intranode,
        )

    def dispatch(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        is_sequence_parallel: bool = False,
        extra_tensors: list[torch.Tensor] | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        raise NotImplementedError

    def combine(
        self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
    ) -> torch.Tensor:
        raise NotImplementedError

    def destroy(self):
        """Destroy all cached handles, then finalize NVSHMEM if internode."""
        with self.handle_cache._lock:
            for handle in self.handle_cache._cache.values():
                handle.destroy()

        if self.internode:
            from pplx_kernels.nvshmem import (  # type: ignore[import-not-found]
                nvshmem_finalize,
            )

            logger.debug("PPLX NVSHMEM finalize")
            nvshmem_finalize()


class DeepEPAll2AllManagerBase(All2AllManagerBase):
    """
    Shared base for the DeepEP All2All managers (the High-Throughput and
    Low-Latency subclasses). It owns the handle cache and the SM budget.
    """

    def __init__(self, cpu_group):
        assert has_deep_ep(), (
            "DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md"
            " to install DeepEP kernels."
        )  # noqa
        super().__init__(cpu_group)
        self.handle_cache = Cache()

        # This is the DeepEP default. Stick to it till we can establish
        # reasonable defaults based on profiling.
        self.num_sms = 30

    def get_handle(self, kwargs):
        """Return a DeepEP buffer handle; implemented by subclasses."""
        raise NotImplementedError

    def dispatch(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        is_sequence_parallel: bool = False,
        extra_tensors: list[torch.Tensor] | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        raise NotImplementedError

    def combine(
        self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
    ) -> torch.Tensor:
        raise NotImplementedError

    def destroy(self):
        pass


class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
    """
    All2All communication based on DeepEP High-Throughput kernels.
    """

    # Fixed buffer / SM sizing for HT mode. These replace sizing derived from
    # envs.VLLM_DEEPEP_BUFFER_SIZE_MB and self.num_sms // 2.
    # NOTE(review): confirm these hard-coded values are still intended.
    _NUM_NVL_BYTES = 1_000_000_000  # 1e9 bytes (not 1 GiB)
    _INTERNODE_NUM_RDMA_BYTES = 500_000_000  # 5e8 bytes
    _INTERNODE_NUM_QPS_PER_RANK = 30
    _INTERNODE_NUM_SMS = 30
    _INTRANODE_NUM_SMS = 60

    def __init__(self, cpu_group):
        super().__init__(cpu_group)

    def _make_all2all_kwargs(self) -> dict[Any, Any]:
        """Build the kwargs used to create the HT-mode ``deep_ep.Buffer``.

        Side effect: sets ``self.num_sms`` (30 for the internode path, 60
        otherwise). ``set_num_sms`` uses that value as its upper bound.
        """
        num_nvl_bytes = self._NUM_NVL_BYTES

        if self.internode and not envs.VLLM_DEEPEP_HIGH_THROUGHPUT_FORCE_INTRA_NODE:
            num_rdma_bytes = self._INTERNODE_NUM_RDMA_BYTES
            num_qps_per_rank = self._INTERNODE_NUM_QPS_PER_RANK
            self.num_sms = self._INTERNODE_NUM_SMS
        else:
            # Intra-node path: RDMA buffer size is 0 and a single QP per rank.
            num_rdma_bytes = 0
            num_qps_per_rank = 1
            self.num_sms = self._INTRANODE_NUM_SMS

        return dict(
            group=self.cpu_group,
            num_nvl_bytes=num_nvl_bytes,
            num_rdma_bytes=num_rdma_bytes,
            low_latency_mode=False,
            num_qps_per_rank=num_qps_per_rank,
        )

    def get_handle(self, kwargs):
        """Return the cached HT ``deep_ep.Buffer``; ``kwargs`` must be empty."""
        assert len(kwargs) == 0, (
            "DeepEPHTAll2AllManager expects no arguments. All the required "
            "args are computed in the Manager itself."
        )

        import deep_ep  # type: ignore[import-not-found]

        buffer_kwargs = self._make_all2all_kwargs()
        logger.debug("DeepEP all2all args %s", buffer_kwargs)
        handle: deep_ep.Buffer = self.handle_cache.get_or_create(
            buffer_kwargs, deep_ep.Buffer
        )
        return handle

    def set_num_sms(self, num_sms: int):
        """Set DeepEP's SM count, clamped to ``self.num_sms``."""
        import deep_ep  # type: ignore[import-not-found]

        # Right now the buffers are sized for only what the kernels were
        # created with. So we can only reduce the number of SMS used
        # but not increase it.
        deep_ep.Buffer.set_num_sms(min(num_sms, self.num_sms))


class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
    """
    All2All communication based on DeepEP Low-Latency kernels.
    """

    def __init__(self, cpu_group):
        super().__init__(cpu_group)

    def _make_all2all_kwargs(
        self,
        max_num_tokens_per_dp_rank: int,
        token_hidden_size: int,
        num_ep_ranks: int,
        num_global_experts: int,
        num_local_experts: int,
    ) -> dict[Any, Any]:
        """Build the kwargs used to create the low-latency ``deep_ep.Buffer``.

        max_num_tokens_per_dp_rank: the maximum number of tokens a DP rank
          can dispatch; all ranks must hold the same value.
        token_hidden_size: the hidden dimension of each token.
        num_ep_ranks: the number of EP group ranks.
        num_global_experts: Number of experts in the model.
        num_local_experts: Number of experts in an EP rank.
        """
        import deep_ep  # type: ignore[import-not-found]

        # Defaults for internode and intranode are taken from DeepEP tests.
        num_nvl_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024
        num_qps_per_rank = num_local_experts
        # Size the RDMA buffer from DeepEP's own low-latency size hint.
        num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
            num_max_dispatch_tokens_per_rank=max_num_tokens_per_dp_rank,
            hidden=token_hidden_size,
            num_ranks=num_ep_ranks,
            num_experts=num_global_experts,
        )
        assert num_rdma_bytes is not None

        return dict(
            group=self.cpu_group,
            num_nvl_bytes=num_nvl_bytes,
            num_rdma_bytes=num_rdma_bytes,
            low_latency_mode=True,
            num_qps_per_rank=num_qps_per_rank,
            allow_nvlink_for_low_latency_mode=True,
            allow_mnnvl=envs.VLLM_DEEPEP_LOW_LATENCY_USE_MNNVL,
        )

    def get_handle(self, kwargs):
        """
        The kwargs for DeepEPLLAll2AllManager is dictated by
        _make_all2all_kwargs.
        """
        import deep_ep  # type: ignore[import-not-found]

        buffer_kwargs = self._make_all2all_kwargs(**kwargs)
        logger.debug("DeepEP all2all args %s", buffer_kwargs)
        handle: deep_ep.Buffer = self.handle_cache.get_or_create(
            buffer_kwargs, deep_ep.Buffer
        )
        return handle

    # DeepEP LL uses RDMA so no SMs are used for communication
    def max_sms_used(self) -> int | None:
        return 0


class FlashInferAllToAllManager(All2AllManagerBase):
    """
    All2All communication based on flashinfer kernels.
    """

    # This type lint could be removed after all of the work in
    # https://github.com/vllm-project/vllm/issues/26533 done.
    rank: int
    world_size: int

    def __init__(self, cpu_group):
        assert has_flashinfer_all2all(), (
            "flashinfer all2all module not found. Please install/check flashinfer"
        )  # noqa
        super().__init__(cpu_group)
        logger.debug(
            "Initialize for flashinfer All2All rank=%d, world size=%d",
            self.rank,
            self.world_size,
        )
        self.initialized = False
        self.alltoall_info = None
        # Workspace state populated by initialize() and reset by cleanup();
        # defined up front so these attributes always exist.
        self.mapping = None
        self.workspace_tensor = None
        self.prepare_workspace_tensor = None

    def initialize(
        self,
        world_size: int,
        rank: int,
        gpus_per_node: int,
    ):
        """Create the flashinfer MNNVL MoE workspaces for this rank.

        No-op if already initialized.
        """
        if self.initialized:
            return

        self.cleanup()
        logger.debug("making map: rank=%d, world size=%d", rank, world_size)
        self.mapping = Mapping(
            world_size,
            rank,
            gpus_per_node,
            tp_size=world_size,
        )

        from vllm.distributed.device_communicators.mnnvl_compat import (
            CustomCommunicator,
        )

        dp_config = MnnvlConfig(
            comm_backend=CustomCommunicator(get_dp_group().cpu_group),
            fabric_page_size=1 << 29,  # 512MB
            allocation_granularity=0,  # Auto-detect
        )

        self.workspace_tensor = MnnvlMoe.get_moe_workspaces(self.mapping, dp_config)
        self.prepare_workspace_tensor = MnnvlMoe.get_moe_prepare_workspace(
            self.mapping, dp_config
        )

        self.world_size = world_size
        self.rank = rank
        self.gpus_per_node = gpus_per_node
        self.initialized = True

        logger.info(
            "FlashInfer All2All initialized for rank %s, size %s", rank, world_size
        )

    def ensure_alltoall_workspace_initialized(self):
        """Ensure the workspace is initialized.

        Returns:
            False if flashinfer all2all is unavailable or world_size <= 1,
            otherwise whether the workspace is initialized.
        """
        if not has_flashinfer_all2all():
            return False

        if self.world_size <= 1:
            return False

        if not self.initialized:
            self.initialize(
                world_size=self.world_size,
                rank=self.rank,
                # Must be the device count (an int), not the function itself.
                gpus_per_node=torch.cuda.device_count(),
            )
        return self.initialized

    def get_handle(self, kwargs):
        return self

    def cleanup(self):
        """Release the workspaces and mark the manager uninitialized."""
        if (
            self.initialized
            and self.workspace_tensor is not None
            and self.prepare_workspace_tensor is not None
        ):
            try:
                del self.workspace_tensor
                del self.prepare_workspace_tensor
            except Exception as e:
                logger.warning("Failed to cleanup FlashInfer workspace: %s", e)
            finally:
                self.workspace_tensor = None
                self.prepare_workspace_tensor = None
                self.mapping = None
                self.initialized = False


class MoriAll2AllManager(All2AllManagerBase):
    """
    All2All communication based on MoRI (ROCm) dispatch/combine kernels.
    """

    def __init__(self, cpu_group):
        assert has_mori(), (
            "MoRI kernels not found. Please follow https://github.com/ROCm/mori/blob/main/README.md"
            " to install MoRI kernels."
        )  # noqa
        import mori  # type: ignore[import-not-found]

        super().__init__(cpu_group)
        self.handle_cache = Cache()

        # Register the CPU group under the name "mori" so MoRI's shmem layer
        # can initialize from it.
        torch._C._distributed_c10d._register_process_group("mori", cpu_group)
        mori.shmem.shmem_torch_process_group_init("mori")

    def _make_all2all_kwargs(
        self,
        rank: int,
        num_ep_ranks: int,
        input_dtype: torch.dtype,
        quant_dtype: torch.dtype,
        token_hidden_size: int,
        scale_dim: int,
        scale_type_size: int,
        max_num_tokens_per_dp_rank: int,
        num_local_experts: int,
        num_experts_per_token: int,
    ) -> dict[str, Any]:
        """Build the ``mori.ops.EpDispatchCombineConfig`` kwargs.

        The kernel type and launch geometry are chosen from ``self.internode``
        and the GPU architecture (gfx942 / gfx950 only).
        """
        import mori  # type: ignore[import-not-found]

        from vllm.platforms.rocm import on_gfx942, on_gfx950

        assert on_gfx942() or on_gfx950(), (
            "mori currently only support arch gfx942 and gfx950"
        )

        if not self.internode:
            # single node
            kernel_type = mori.ops.EpDispatchCombineKernelType.IntraNode
            rdma_block_num = 0
            warp_num_per_block = 16
            block_num = 80
        else:
            # multi node
            kernel_type = mori.ops.EpDispatchCombineKernelType.InterNodeV1
            if on_gfx942():
                warp_num_per_block = 16
                block_num = 32
                rdma_block_num = 16
            elif on_gfx950():
                warp_num_per_block = 8
                block_num = 64
                rdma_block_num = 32
            else:
                # Only reachable when asserts are disabled (python -O).
                raise NotImplementedError(
                    "mori currently only support arch gfx942 and gfx950"
                )

        return dict(
            rank=rank,
            world_size=num_ep_ranks,
            data_type=quant_dtype,
            hidden_dim=token_hidden_size,
            scale_dim=scale_dim,
            scale_type_size=scale_type_size,
            max_token_type_size=input_dtype.itemsize,
            max_num_inp_token_per_rank=max_num_tokens_per_dp_rank,
            num_experts_per_rank=num_local_experts,
            num_experts_per_token=num_experts_per_token,
            warp_num_per_block=warp_num_per_block,
            block_num=block_num,
            kernel_type=kernel_type,
            rdma_block_num=rdma_block_num,
            # NOTE(review): assumes at most 8 GPUs per node.
            gpu_per_node=min(8, num_ep_ranks),
        )

    def _make_handle(self, **kwargs):
        """Create a MoRI dispatch/combine op from config kwargs."""
        import mori  # type: ignore[import-not-found]

        mori_config = mori.ops.EpDispatchCombineConfig(**kwargs)
        handle = mori.ops.EpDispatchCombineOp(mori_config)
        return handle

    def get_handle(self, kwargs):
        """Return the cached MoRI op for ``kwargs``, creating it if needed."""
        import mori  # type: ignore[import-not-found]

        mori_kwargs = self._make_all2all_kwargs(**kwargs)
        logger.debug("MoRI all2all args %s", mori_kwargs)
        handle: mori.ops.EpDispatchCombineOp = self.handle_cache.get_or_create(
            mori_kwargs, self._make_handle
        )
        return handle