all2all.py 16.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from typing import Any
4

5
import torch
6
import torch.distributed as dist
7

8
import vllm.envs as envs
9
from vllm.distributed import get_dp_group, get_ep_group
10
from vllm.forward_context import get_forward_context
11
from vllm.logger import init_logger
12
from vllm.utils.flashinfer import has_flashinfer_all2all
13
from vllm.utils.import_utils import has_deep_ep, has_pplx
14

15
from .base_device_communicator import All2AllManagerBase, Cache
16

17
if has_flashinfer_all2all():
18
19
20
21
22
    from flashinfer.comm import Mapping  # type: ignore[import-not-found]
    from flashinfer.comm.mnnvl import MnnvlConfig  # type: ignore[import-not-found]
    from flashinfer.comm.trtllm_alltoall import (
        MnnvlMoe,  # type: ignore[import-not-found]
    )
23

24
logger = init_logger(__name__)
25
26


27
class NaiveAll2AllManager(All2AllManagerBase):
    """
    Reference all2all backend built from plain collectives.

    ``dispatch`` replicates every rank's token slice onto all ranks using one
    broadcast per source rank; ``combine`` all-reduces the full buffer and
    keeps only this rank's rows. This is deliberately simple and slow -- it
    exists for testing and debugging, not production use.
    """

    def __init__(self, cpu_group):
        super().__init__(cpu_group)

    def naive_multicast(
        self,
        x: torch.Tensor,
        cu_tokens_across_sp_cpu: torch.Tensor,
        is_sequence_parallel: bool,
    ) -> torch.Tensor:
        """
        Assemble every rank's 2-D slice of ``x`` into one buffer on all ranks.

        ``cu_tokens_across_sp_cpu[i]`` is the exclusive end row of rank ``i``'s
        slice; the slice starts at the previous entry (or at row 0 for rank 0).
        The rank/world size used is the global one under sequence parallelism
        and the DP one otherwise.
        """
        assert len(x.shape) == 2
        gathered = torch.empty(
            (cu_tokens_across_sp_cpu[-1], x.size(1)), device=x.device, dtype=x.dtype
        )

        if is_sequence_parallel:
            my_rank, num_ranks = self.rank, self.world_size
        else:
            my_rank, num_ranks = self.dp_rank, self.dp_world_size

        def _row_bounds(r):
            # [lo, hi) row range owned by rank r in the gathered buffer.
            lo = 0 if r == 0 else cu_tokens_across_sp_cpu[r - 1]
            return lo, cu_tokens_across_sp_cpu[r]

        lo, hi = _row_bounds(my_rank)
        gathered[lo:hi, :].copy_(x)

        # Every rank takes a turn as broadcast source for its own slice.
        ep_group = get_ep_group()
        for src in range(num_ranks):
            lo, hi = _row_bounds(src)
            ep_group.broadcast(gathered[lo:hi, :], src)

        return gathered

    def dispatch(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        is_sequence_parallel: bool = False,
        extra_tensors: list[torch.Tensor] | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Replicate ``hidden_states`` and ``router_logits`` from all ranks.

        Raises:
            NotImplementedError: if ``extra_tensors`` is given.
        """
        if extra_tensors is not None:
            raise NotImplementedError(
                "extra_tensors is not supported for NaiveAll2AllManager"
            )

        sp_size = self.tp_group.world_size if is_sequence_parallel else 1
        dp_metadata = get_forward_context().dp_metadata
        assert dp_metadata is not None
        cu_tokens = dp_metadata.cu_tokens_across_sp(sp_size)

        return (
            self.naive_multicast(hidden_states, cu_tokens, is_sequence_parallel),
            self.naive_multicast(router_logits, cu_tokens, is_sequence_parallel),
        )

    def combine(
        self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
    ) -> torch.Tensor:
        """
        Sum ``hidden_states`` over the EP group and return this rank's rows.
        """
        dp_metadata = get_forward_context().dp_metadata
        assert dp_metadata is not None

        my_rank = self.rank if is_sequence_parallel else self.dp_rank
        sp_size = self.tp_group.world_size if is_sequence_parallel else 1
        cu_tokens = dp_metadata.cu_tokens_across_sp(sp_size)

        row_begin = cu_tokens[my_rank - 1] if my_rank != 0 else 0
        row_end = cu_tokens[my_rank]

        reduced = get_ep_group().all_reduce(hidden_states)
        return reduced[row_begin:row_end, :]

    def destroy(self):
        """Nothing to release: this backend allocates no persistent state."""
        pass


class AgRsAll2AllManager(All2AllManagerBase):
    """
    All2all backend that dispatches with an all-gather(v) and combines with a
    reduce-scatter(v), using the per-DP-rank chunk sizes from the forward
    context.
    """

    def __init__(self, cpu_group):
        super().__init__(cpu_group)

    @staticmethod
    def _chunk_sizes_and_group(is_sequence_parallel: bool):
        """Return (per-DP-rank chunk sizes, collective group) for this step."""
        dp_metadata = get_forward_context().dp_metadata
        assert dp_metadata is not None
        chunk_sizes = dp_metadata.get_chunk_sizes_across_dp_rank()
        assert chunk_sizes is not None
        # Sequence parallelism spans the EP group; otherwise only DP ranks.
        group = get_ep_group() if is_sequence_parallel else get_dp_group()
        return chunk_sizes, group

    def dispatch(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        is_sequence_parallel: bool = False,
        extra_tensors: list[torch.Tensor] | None = None,
    ) -> (
        tuple[torch.Tensor, torch.Tensor]
        | tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
    ):
        """
        All-gather ``hidden_states``, ``router_logits`` and any
        ``extra_tensors`` along dim 0 across the group.

        Returns a 2-tuple, or a 3-tuple whose last item holds the gathered
        extra tensors when ``extra_tensors`` is not None.
        """
        chunk_sizes, group = self._chunk_sizes_and_group(is_sequence_parallel)
        # This rank's local token count must match its advertised chunk size.
        assert chunk_sizes[group.rank_in_group] == hidden_states.shape[0]

        payload = [hidden_states, router_logits]
        if extra_tensors is not None:
            payload += extra_tensors

        gathered = group.all_gatherv(payload, dim=0, sizes=chunk_sizes)

        if extra_tensors is None:
            return gathered[0], gathered[1]
        return gathered[0], gathered[1], gathered[2:]

    def combine(
        self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
    ) -> torch.Tensor:
        """
        Reduce-scatter ``hidden_states`` along dim 0 so each rank keeps the
        summed rows for its own chunk.
        """
        chunk_sizes, group = self._chunk_sizes_and_group(is_sequence_parallel)
        return group.reduce_scatterv(hidden_states, dim=0, sizes=chunk_sizes)

    def destroy(self):
        """Nothing to release: this backend allocates no persistent state."""
        pass


class PPLXAll2AllManager(All2AllManagerBase):
    """
    All2All communication based on PPLX kernels.

    Handles are obtained through ``get_handle`` and kept in ``handle_cache``;
    ``destroy`` releases them and, for inter-node groups, finalizes NVSHMEM.
    """

    def __init__(self, cpu_group):
        assert has_pplx(), (
            "pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md"
            " to install pplx_kernels."
        )
        super().__init__(cpu_group)

        if self.internode:
            # inter-node communication needs nvshmem,
            # intra-node communication uses p2p mapping directly
            from pplx_kernels.nvshmem import (  # type: ignore[import-not-found]
                nvshmem_alloc_empty_unique_id,
                nvshmem_get_unique_id,
                nvshmem_init,
            )

            logger.debug(
                "Initialize NVSHMEM for pplx_kernels: rank=%d, world size=%d",
                self.rank,
                self.world_size,
            )
            # Rank 0 creates the NVSHMEM unique id; the other ranks allocate an
            # empty id which the broadcast (sourced from the group's first
            # member) overwrites, so every rank ends up with the same id.
            uid = (
                nvshmem_get_unique_id()
                if self.rank == 0
                else nvshmem_alloc_empty_unique_id()
            )
            dist.broadcast(
                uid,
                src=dist.get_process_group_ranks(self.cpu_group)[0],
                group=self.cpu_group,
            )
            logger.debug("PPLX NVSHMEM UID = %s", uid)
            nvshmem_init(uid, self.rank, self.world_size)

        self.handle_cache = Cache()

    def get_handle(self, kwargs):
        """
        Return the pplx ``AllToAll`` handle for ``kwargs`` via
        ``handle_cache.get_or_create``, built with the internode or intranode
        constructor depending on ``self.internode``.
        """
        import pplx_kernels as pplx  # type: ignore[import-not-found]

        return self.handle_cache.get_or_create(
            kwargs,
            pplx.AllToAll.internode if self.internode else pplx.AllToAll.intranode,
        )

    def dispatch(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        is_sequence_parallel: bool = False,
        extra_tensors: list[torch.Tensor] | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        raise NotImplementedError

    def combine(
        self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
    ) -> torch.Tensor:
        raise NotImplementedError

    def destroy(self):
        """
        Destroy every cached PPLX handle, then finalize NVSHMEM when it was
        initialized in ``__init__`` (the internode case).
        """
        with self.handle_cache._lock:
            # Snapshot first so destroying handles cannot disturb iteration.
            handles = list(self.handle_cache._cache.values())
            for handle in handles:
                handle.destroy()
            # Forget the destroyed handles so a repeated destroy() does not
            # call destroy() on them a second time.
            self.handle_cache._cache.clear()

        if self.internode:
            from pplx_kernels.nvshmem import (  # type: ignore[import-not-found]
                nvshmem_finalize,
            )

            logger.debug("PPLX NVSHMEM finalize")
            nvshmem_finalize()


class DeepEPAll2AllManagerBase(All2AllManagerBase):
    """
    Shared base for the DeepEP all2all managers (both the high-throughput and
    the low-latency variants).

    It verifies DeepEP is installed, owns the handle cache and the default SM
    budget (``num_sms``), and leaves ``get_handle`` to the subclasses. The
    ``dispatch``/``combine`` entry points are not implemented here.
    """

    def __init__(self, cpu_group):
        assert has_deep_ep(), (
            "DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md"
            " to install DeepEP kernels."
        )  # noqa
        super().__init__(cpu_group)
        self.handle_cache = Cache()

        # This is the DeepEP default. Stick to it till we can establish
        # reasonable defaults based on profiling.
        self.num_sms = 30

    def get_handle(self, kwargs):
        """Return a DeepEP buffer handle; must be overridden by subclasses."""
        raise NotImplementedError

    def dispatch(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        is_sequence_parallel: bool = False,
        extra_tensors: list[torch.Tensor] | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        raise NotImplementedError

    def combine(
        self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
    ) -> torch.Tensor:
        raise NotImplementedError

    def destroy(self):
        pass


class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
    """
    All2All communication based on DeepEP High-Throughput kernels.
    """

    def __init__(self, cpu_group):
        super().__init__(cpu_group)

    def _make_all2all_kwargs(self) -> dict[Any, Any]:
        """
        Build the ``deep_ep.Buffer`` constructor kwargs for high-throughput
        (``low_latency_mode=False``) operation.

        Side effect: overwrites ``self.num_sms`` (30 on the internode/RDMA
        path, 60 otherwise). That value is the cap ``set_num_sms`` enforces.

        NOTE(review): the buffer sizes here are hard-coded and do not use
        ``envs.VLLM_DEEPEP_BUFFER_SIZE_MB`` (the env-based sizing is commented
        out in this copy); confirm the override is intentional.
        """
        # Fixed NVLink buffer of 1e9 bytes (== int(2e9 / 2)).
        nvl_bytes = 1_000_000_000

        use_rdma = (
            self.internode and not envs.VLLM_DEEPEP_HIGH_THROUGHPUT_FORCE_INTRA_NODE
        )
        if use_rdma:
            # 5e8-byte RDMA buffer (== int(1e9 / 2)), 30 QPs per rank.
            rdma_bytes, qps_per_rank = 500_000_000, 30
            self.num_sms = 30
        else:
            # Intra-node only: no RDMA buffer, a single QP.
            rdma_bytes, qps_per_rank = 0, 1
            self.num_sms = 60

        return {
            "group": self.cpu_group,
            "num_nvl_bytes": nvl_bytes,
            "num_rdma_bytes": rdma_bytes,
            "low_latency_mode": False,
            "num_qps_per_rank": qps_per_rank,
        }

    def get_handle(self, kwargs):
        """
        Return the high-throughput ``deep_ep.Buffer`` from the handle cache.

        ``kwargs`` must be empty; the buffer arguments are derived internally
        by ``_make_all2all_kwargs``.
        """
        assert len(kwargs) == 0, (
            "DeepEPHTAll2AllManager expects no arguments. All the required "
            "args are computed in the Manager itself."
        )

        import deep_ep  # type: ignore[import-not-found]

        buffer_kwargs = self._make_all2all_kwargs()
        logger.debug("DeepEP all2all args %s", buffer_kwargs)
        return self.handle_cache.get_or_create(buffer_kwargs, deep_ep.Buffer)

    def set_num_sms(self, num_sms: int):
        """Set DeepEP's SM count, clamped to ``self.num_sms``."""
        import deep_ep  # type: ignore[import-not-found]

        # Right now the buffers are sized for only what the kernels were
        # created with, so the SM count may shrink but never grow past it.
        deep_ep.Buffer.set_num_sms(min(num_sms, self.num_sms))


class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
    """
    All2All communication based on DeepEP Low-Latency kernels.
    """

    def __init__(self, cpu_group):
        super().__init__(cpu_group)

    def _make_all2all_kwargs(
        self,
        max_num_tokens_per_dp_rank: int,
        token_hidden_size: int,
        num_ep_ranks: int,
        num_global_experts: int,
        num_local_experts: int,
    ) -> dict[Any, Any]:
        """
        Build the ``deep_ep.Buffer`` constructor kwargs for low-latency mode.

        max_num_tokens_per_dp_rank: the maximum number of tokens a DP rank can
          dispatch; every rank must pass the same value.
        token_hidden_size: the hidden dimension of each token.
        num_ep_ranks: the number of ranks in the EP group.
        num_global_experts: the total number of experts in the model.
        num_local_experts: the number of experts on one EP rank; also used as
          the number of queue pairs per rank.
        """
        import deep_ep  # type: ignore[import-not-found]

        # NVLink buffer size comes from the env var (in MiB).
        nvl_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024

        # RDMA buffer size comes from DeepEP's own sizing hint.
        rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
            num_max_dispatch_tokens_per_rank=max_num_tokens_per_dp_rank,
            hidden=token_hidden_size,
            num_ranks=num_ep_ranks,
            num_experts=num_global_experts,
        )
        assert rdma_bytes is not None

        return {
            "group": self.cpu_group,
            "num_nvl_bytes": nvl_bytes,
            "num_rdma_bytes": rdma_bytes,
            "low_latency_mode": True,
            "num_qps_per_rank": num_local_experts,
            "allow_nvlink_for_low_latency_mode": True,
            "allow_mnnvl": envs.VLLM_DEEPEP_LOW_LATENCY_USE_MNNVL,
        }

    def get_handle(self, kwargs):
        """
        Return the low-latency ``deep_ep.Buffer`` from the handle cache.

        ``kwargs`` are forwarded verbatim to ``_make_all2all_kwargs``.
        """
        import deep_ep  # type: ignore[import-not-found]

        buffer_kwargs = self._make_all2all_kwargs(**kwargs)
        logger.debug("DeepEP all2all args %s", buffer_kwargs)
        return self.handle_cache.get_or_create(buffer_kwargs, deep_ep.Buffer)

    def max_sms_used(self) -> int | None:
        """Report 0: DeepEP LL uses RDMA so no SMs are used for communication."""
        return 0


class FlashInferAllToAllManager(All2AllManagerBase):
    """
    All2All communication based on flashinfer kernels.

    The MNNVL MoE workspaces are allocated lazily by ``initialize`` (normally
    reached through ``ensure_alltoall_workspace_initialized``) and released by
    ``cleanup``.
    """

    # This type lint could be removed after all of the work in
    # https://github.com/vllm-project/vllm/issues/26533 done.
    rank: int
    world_size: int

    def __init__(self, cpu_group):
        assert has_flashinfer_all2all(), (
            "flashinfer all2all module not found. Please install/check flashinfer"
        )  # noqa
        super().__init__(cpu_group)
        logger.debug(
            "Initialize for flashinfer All2All rank=%d, world size=%d",
            self.rank,
            self.world_size,
        )
        self.initialized = False
        self.alltoall_info = None
        # Workspace state filled in by initialize() and reset by cleanup().
        # Declared here so the attributes exist before the first initialize().
        self.mapping = None
        self.workspace_tensor = None
        self.prepare_workspace_tensor = None
        self.gpus_per_node = None

    def initialize(
        self,
        world_size: int,
        rank: int,
        gpus_per_node: int,
    ):
        """
        Allocate the MNNVL MoE workspaces for this rank.

        No-op when already initialized. Builds a flashinfer ``Mapping`` with
        ``tp_size=world_size`` and allocates both the MoE workspace and the
        prepare workspace over the DP group's CPU communicator.
        """
        if self.initialized:
            return

        self.cleanup()
        logger.debug("making map: rank=%d, world size=%d", rank, world_size)
        self.mapping = Mapping(
            world_size,
            rank,
            gpus_per_node,
            tp_size=world_size,
        )

        from vllm.distributed.device_communicators.mnnvl_compat import (
            CustomCommunicator,
        )

        dp_config = MnnvlConfig(
            comm_backend=CustomCommunicator(get_dp_group().cpu_group),
            fabric_page_size=1 << 29,  # 512MB
            allocation_granularity=0,  # Auto-detect
        )

        self.workspace_tensor = MnnvlMoe.get_moe_workspaces(self.mapping, dp_config)
        self.prepare_workspace_tensor = MnnvlMoe.get_moe_prepare_workspace(
            self.mapping, dp_config
        )

        self.world_size = world_size
        self.rank = rank
        self.gpus_per_node = gpus_per_node
        self.initialized = True

        logger.info(
            "FlashInfer All2All initialized for rank %s, size %s", rank, world_size
        )

    def ensure_alltoall_workspace_initialized(self):
        """
        Lazily initialize the workspace and report whether it is ready.

        Returns False without initializing when flashinfer all2all is not
        available or ``world_size <= 1``; otherwise returns ``self.initialized``.
        """
        if not has_flashinfer_all2all():
            return False

        if self.world_size <= 1:
            return False

        if not self.initialized:
            self.initialize(
                world_size=self.world_size,
                rank=self.rank,
                # Must be called: Mapping expects an int GPU count, not the
                # torch.cuda.device_count function object.
                gpus_per_node=torch.cuda.device_count(),
            )
        return self.initialized

    def get_handle(self, kwargs):
        """The manager itself serves as the handle; ``kwargs`` is ignored."""
        return self

    def cleanup(self):
        """
        Release the workspaces allocated by ``initialize``.

        Drops this manager's references to both workspace tensors and the
        mapping, and marks the manager uninitialized so a later ``initialize``
        reallocates. No-op unless initialized with both workspaces present.
        """
        if not (
            self.initialized
            and self.workspace_tensor is not None
            and self.prepare_workspace_tensor is not None
        ):
            return

        self.workspace_tensor = None
        self.prepare_workspace_tensor = None
        self.mapping = None
        self.initialized = False