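# all2all.py (12.3 KB)
# NOTE: the "Newer Older" label and the bare numbers interleaved below are
# residue from a web view of this file, not part of the source.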
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from typing import TYPE_CHECKING, Any
4

5
import torch
6
import torch.distributed as dist
7
8

from vllm.forward_context import get_forward_context
9
from vllm.logger import init_logger
10
from vllm.utils import has_deep_ep, has_pplx
11
import vllm.envs as envs
12
from .base_device_communicator import All2AllManagerBase, Cache
13

14
# FusedMoE is only imported for static type checkers. At runtime the name is
# bound to None so the module does not import the fused_moe layer module
# (presumably to avoid an import cycle -- verify before changing).
if TYPE_CHECKING:
    from vllm.model_executor.layers.fused_moe.layer import FusedMoE
else:
    FusedMoE = None

logger = init_logger(__name__)
20
21


22
class NaiveAll2AllManager(All2AllManagerBase):
    """
    Reference all2all built from plain collectives on the DP group.

    ``dispatch`` broadcasts every DP rank's tokens so each rank holds the
    tokens of all ranks. ``combine`` all-reduces the full tensor and keeps
    only this rank's rows. This is very inefficient and is meant for
    testing and debugging only.
    """

    def __init__(self, cpu_group):
        super().__init__(cpu_group)

    @staticmethod
    def _dp_token_slice(cu_tokens_across_dp_cpu, rank):
        """Return the ``(start, end)`` rows owned by DP rank ``rank``.

        ``cu_tokens_across_dp_cpu`` is indexed as a running total: entry
        ``rank - 1`` is where ``rank``'s rows start and entry ``rank`` is
        where they end.
        """
        start = 0 if rank == 0 else cu_tokens_across_dp_cpu[rank - 1]
        return start, cu_tokens_across_dp_cpu[rank]

    def naive_multicast(self, x: torch.Tensor,
                        cu_tokens_across_dp_cpu: torch.Tensor):
        """Gather the 2-D ``x`` of every DP rank into one tensor per rank.

        Returns a tensor with ``cu_tokens_across_dp_cpu[-1]`` rows in which
        each DP rank's rows sit at that rank's slice.
        """
        assert x.dim() == 2
        gathered = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)),
                               device=x.device,
                               dtype=x.dtype)

        # Drop the local rows into place first ...
        lo, hi = self._dp_token_slice(cu_tokens_across_dp_cpu, self.dp_rank)
        gathered[lo:hi, :].copy_(x)

        # ... then let every rank broadcast its own slice to the others.
        for src_rank in range(self.dp_world_size):
            lo, hi = self._dp_token_slice(cu_tokens_across_dp_cpu, src_rank)
            self.dp_group.broadcast(gathered[lo:hi, :], src_rank)

        return gathered

    def dispatch(self, hidden_states: torch.Tensor,
                 router_logits: torch.Tensor):
        """Gather hidden states and router logits from all DP ranks."""
        cu_tokens = get_forward_context().dp_metadata.cu_tokens_across_dp_cpu
        gathered_states = self.naive_multicast(hidden_states, cu_tokens)
        gathered_logits = self.naive_multicast(router_logits, cu_tokens)
        return gathered_states, gathered_logits

    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """All-reduce ``hidden_states`` over DP and keep this rank's rows."""
        cu_tokens = get_forward_context().dp_metadata.cu_tokens_across_dp_cpu
        lo, hi = self._dp_token_slice(cu_tokens, self.dp_rank)
        reduced = self.dp_group.all_reduce(hidden_states)
        return reduced[lo:hi, :]

    def destroy(self):
        # Nothing is allocated by this manager.
        pass


class PPLXAll2AllManager(All2AllManagerBase):
    """
    All2All communication backed by pplx_kernels ``AllToAll`` handles.

    Intra-node traffic uses p2p mappings directly. Inter-node traffic needs
    NVSHMEM, which is initialized in the constructor and finalized in
    ``destroy``.
    """

    def __init__(self, cpu_group):
        assert has_pplx(
        ), "pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install pplx_kernels."  # noqa
        super().__init__(cpu_group)

        if self.internode:
            self._init_nvshmem()

        self.handle_cache = Cache()

    def _init_nvshmem(self):
        """Share rank 0's NVSHMEM unique id over ``cpu_group`` and init."""
        from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id,
                                          nvshmem_get_unique_id,
                                          nvshmem_init)
        logger.debug(
            "Initialize NVSHMEM for pplx_kernels: "
            "rank=%d, world size=%d", self.rank, self.world_size)
        if self.rank == 0:
            uid = nvshmem_get_unique_id()
        else:
            uid = nvshmem_alloc_empty_unique_id()
        # The broadcast source is the first global rank of cpu_group.
        root_rank = dist.get_process_group_ranks(self.cpu_group)[0]
        dist.broadcast(uid, src=root_rank, group=self.cpu_group)
        logger.debug("PPLX NVSHMEM UID = %s", uid)
        nvshmem_init(uid, self.rank, self.world_size)

    def get_handle(self, kwargs):
        """Return a cached pplx ``AllToAll`` handle built from ``kwargs``."""
        import pplx_kernels as pplx
        if self.internode:
            factory = pplx.AllToAll.internode
        else:
            factory = pplx.AllToAll.intranode
        return self.handle_cache.get_or_create(kwargs, factory)

    def dispatch(self, hidden_states: torch.Tensor,
                 router_logits: torch.Tensor):
        raise NotImplementedError

    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    def destroy(self):
        """Destroy every cached handle, then finalize NVSHMEM if used."""
        with self.handle_cache._lock:
            for cached_handle in self.handle_cache._cache.values():
                cached_handle.destroy()

        if not self.internode:
            return

        from pplx_kernels.nvshmem import nvshmem_finalize
        logger.debug("PPLX NVSHMEM finalize")
        nvshmem_finalize()


class DeepEPAll2AllManagerBase(All2AllManagerBase):
    """
    Shared base for the DeepEP all2all managers in this module (the
    high-throughput and the low-latency managers both derive from it).

    It owns the handle cache and the ``num_sms`` value that subclasses apply
    to their ``deep_ep.Buffer`` handles. Creating handles, dispatch and
    combine are left to subclasses.
    """

    def __init__(self, cpu_group):
        assert has_deep_ep(
        ), "DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install DeepEP kernels."  # noqa
        super().__init__(cpu_group)
        self.handle_cache = Cache()

        # SM budget that subclasses pass to Buffer.set_num_sms().
        # NOTE(review): the value used to be 20, which the earlier comment
        # described as the DeepEP default. 24 is a local override and is not
        # backed by documented profiling here -- revisit once reasonable
        # defaults are established.
        self.num_sms = 24

    def get_handle(self, kwargs):
        raise NotImplementedError

    def dispatch(self, hidden_states: torch.Tensor,
                 router_logits: torch.Tensor):
        raise NotImplementedError

    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    def destroy(self):
        # Cached handles are not explicitly destroyed by this base class.
        pass


class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
    """
    All2All communication based on DeepEP High-Throughput kernels.
    """

    def __init__(self, cpu_group):
        super().__init__(cpu_group)

    def _make_all2all_kwargs(self) -> dict[Any, Any]:
        """
        Build the ``deep_ep.Buffer`` constructor kwargs for HT mode.

        Intra-node: no RDMA buffer and a single QP per rank.
        Inter-node: a fixed-size RDMA buffer and a fixed QP count.
        The result is also used as the cache key in ``get_handle``.
        """
        # NVLink buffer: 1e9 bytes. This was previously written as
        # int(2e9/2); its stale trailing comment referred to 1024 * 1024 * 1024.
        num_nvl_bytes = 1_000_000_000

        if self.internode:
            # RDMA buffer: 5e8 bytes. This was previously int(1e9/2); its stale
            # trailing comment referred to 1024 * 1024 * 1024.
            num_rdma_bytes = 500_000_000
            # Fixed QP count. The stale trailing comment on this value
            # referred to self.num_sms // 2.
            num_qps_per_rank = 30
            # TODO: derive the NVL/RDMA sizes from DeepEP's hints
            # (deep_ep.Buffer.get_dispatch_config / get_combine_config with
            # get_nvl_buffer_size_hint / get_rdma_buffer_size_hint) instead
            # of fixed constants.
        else:
            num_rdma_bytes = 0
            num_qps_per_rank = 1

        return dict(group=self.cpu_group,
                    num_nvl_bytes=num_nvl_bytes,
                    num_rdma_bytes=num_rdma_bytes,
                    low_latency_mode=False,
                    num_qps_per_rank=num_qps_per_rank,
                    explicitly_destroy=False)

    def get_handle(self, kwargs):
        """Return the cached HT ``deep_ep.Buffer`` for this manager."""
        assert len(kwargs) == 0, (
            "DeepEPHTAll2AllManager expects no arguments. All the required "
            "args are computed in the Manager itself.")

        import deep_ep
        buffer_kwargs = self._make_all2all_kwargs()
        logger.debug("DeepEP all2all args %s", buffer_kwargs)
        handle: deep_ep.Buffer = self.handle_cache.get_or_create(
            buffer_kwargs, deep_ep.Buffer)
        # It is dangerous to set num sms outside this function. num_sms is not
        # a part of the hash-key that identifies this object. If we are in a
        # situation where we make objects with different num_sms, the hash key
        # in get_or_create must be updated.
        handle.set_num_sms(self.num_sms)
        return handle


class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
    """
    All2All communication based on DeepEP Low-Latency kernels.
    """

    def __init__(self, cpu_group):
        super().__init__(cpu_group)

    def _make_all2all_kwargs(
        self,
        max_num_tokens_per_dp_rank: int,
        token_hidden_size: int,
        num_ep_ranks: int,
        num_global_experts: int,
        num_local_experts: int,
    ) -> dict[Any, Any]:
        """
        Build the ``deep_ep.Buffer`` constructor kwargs for LL mode.

        max_num_tokens_per_dp_rank: the most tokens a DP rank may dispatch;
          all ranks must use the same value.
        token_hidden_size: hidden dimension of each token.
        num_ep_ranks: number of ranks in the EP group.
        num_global_experts: number of experts in the model.
        num_local_experts: number of experts on one EP rank; also used as
          the QP count per rank.

        The RDMA size comes from DeepEP's low-latency size hint. The NVL
        size is fixed at 1 GiB.
        """
        import deep_ep

        rdma_size = deep_ep.Buffer.get_low_latency_rdma_size_hint(
            num_max_dispatch_tokens_per_rank=max_num_tokens_per_dp_rank,
            hidden=token_hidden_size,
            num_ranks=num_ep_ranks,
            num_experts=num_global_experts)
        assert rdma_size is not None

        return {
            "group": self.cpu_group,
            "num_nvl_bytes": 1024 * 1024 * 1024,
            "num_rdma_bytes": rdma_size,
            "low_latency_mode": True,
            "num_qps_per_rank": num_local_experts,
            "allow_mnnvl": envs.VLLM_ALLOW_MNNVL,
        }

    def get_handle(self, kwargs):
        """
        Return a cached LL ``deep_ep.Buffer``.

        ``kwargs`` must match the parameters of ``_make_all2all_kwargs``.
        """
        import deep_ep
        buffer_kwargs = self._make_all2all_kwargs(**kwargs)
        logger.debug("DeepEP all2all args %s", buffer_kwargs)
        ll_buffer: deep_ep.Buffer = self.handle_cache.get_or_create(
            buffer_kwargs, deep_ep.Buffer)
        # num_sms is not part of the cache key, so it is re-applied on every
        # lookup. Handles that need different num_sms values would require
        # the key used by get_or_create to change.
        ll_buffer.set_num_sms(self.num_sms)
        return ll_buffer


class DeepEPAutoAll2AllManager(All2AllManagerBase):
    """
    DeepEP manager that hands out one buffer intended for both the
    high-throughput (HT) and low-latency (LL) paths.

    The wrapped HT and LL managers are used only to compute their buffer
    kwargs. ``get_handle`` merges them into a single LL-mode buffer stored
    in the LL manager's handle cache, so only one buffer instance is
    created. Relying on LL buffers this way is meant to mirror sglang.
    Dispatch and combine are not implemented here.
    """

    def __init__(self, cpu_group):
        super().__init__(cpu_group)
        self.ll_manager = DeepEPLLAll2AllManager(cpu_group)
        self.ht_manager = DeepEPHTAll2AllManager(cpu_group)

    def get_handle(self, kwargs):
        """
        Return a cached ``deep_ep.Buffer`` sized for both HT and LL use.

        ``kwargs`` is forwarded to the LL manager's ``_make_all2all_kwargs``.
        The LL kwargs are used as-is (including ``low_latency_mode=True`` and
        the LL ``num_qps_per_rank``). The exceptions are ``num_nvl_bytes`` and
        ``num_rdma_bytes``, which are raised to the max of the HT and LL values.
        HT-only kwargs such as ``explicitly_destroy`` are not carried over.
        """
        import deep_ep

        # dict() accepts any mapping (or iterable of pairs) before unpacking.
        ll_kwargs = self.ll_manager._make_all2all_kwargs(**dict(kwargs))
        ht_kwargs = self.ht_manager._make_all2all_kwargs()

        # Take the larger buffer size for each memory pool so one buffer
        # satisfies both modes.
        merged_kwargs = dict(ll_kwargs)
        for size_key in ("num_nvl_bytes", "num_rdma_bytes"):
            merged_kwargs[size_key] = max(ll_kwargs[size_key],
                                          ht_kwargs[size_key])

        logger.debug("DeepEP auto merged args %s", merged_kwargs)
        handle: deep_ep.Buffer = self.ll_manager.handle_cache.get_or_create(
            merged_kwargs, deep_ep.Buffer)
        # num_sms is not part of the cache key, so it is re-applied on every
        # call.
        handle.set_num_sms(self.ll_manager.num_sms)
        return handle

    def dispatch(self, hidden_states: torch.Tensor,
                 router_logits: torch.Tensor):
        raise NotImplementedError(
            "DeepEPAutoAll2AllManager does not support dispatch directly; "
            "use the underlying HT/LL managers.")

    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError(
            "DeepEPAutoAll2AllManager does not support combine directly; "
            "use the underlying HT/LL managers.")

    def destroy(self):
        # Tear down both wrapped managers, not just the LL one, so neither
        # manager's resources are skipped if its destroy() gains real work.
        self.ll_manager.destroy()
        self.ht_manager.destroy()