parallel_utils.py 5.02 KB
Newer Older
bnellnm's avatar
bnellnm committed
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
bnellnm's avatar
bnellnm committed
3
4
5
"""
DeepEP test utilities
"""
6

bnellnm's avatar
bnellnm committed
7
8
9
import dataclasses
import os
import traceback
10
11
from collections.abc import Callable
from typing import Concatenate
bnellnm's avatar
bnellnm committed
12
13
14

import torch
from torch.distributed import ProcessGroup
15
from torch.multiprocessing import spawn  # pyright: ignore[reportPrivateImportUsage]
16
from typing_extensions import ParamSpec
bnellnm's avatar
bnellnm committed
17

18
from vllm.utils.import_utils import has_deep_ep
19
from vllm.utils.network_utils import get_open_port
bnellnm's avatar
bnellnm committed
20

21
if has_deep_ep():
22
    from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import (
23
24
        DeepEPHTPrepareAndFinalize,
    )
25
    from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import (
26
27
        DeepEPLLPrepareAndFinalize,
    )
bnellnm's avatar
bnellnm committed
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102

## Parallel Processes Utils

P = ParamSpec("P")


@dataclasses.dataclass
class ProcessGroupInfo:
    """Rank and device coordinates describing one process of a launch.

    Attributes:
        world_size: Total number of processes in the group.
        world_local_size: Number of processes per node.
        rank: Global rank of this process.
        node_rank: Index of the node this process runs on.
        local_rank: Index of this process within its node.
        device: Torch device assigned to this process.
    """

    # Field order is part of the positional constructor; keep it stable.
    world_size: int
    world_local_size: int
    rank: int
    node_rank: int
    local_rank: int
    device: torch.device


def _worker_parallel_launch(
    local_rank: int,
    world_size: int,
    world_local_size: int,
    node_rank: int,
    init_method: str,
    worker: Callable[Concatenate[ProcessGroupInfo, P], None],
    *args: P.args,
    **kwargs: P.kwargs,
) -> None:
    """Per-process entry point: join the process group, then run ``worker``.

    Args:
        local_rank: Index of this process on its node; also selects the
            CUDA device.
        world_size: Total number of processes in the group.
        world_local_size: Number of processes per node.
        node_rank: Index of the node this process runs on.
        init_method: Rendezvous URL handed to ``init_process_group``.
        worker: Callable invoked with a ``ProcessGroupInfo`` followed by
            ``*args`` and ``**kwargs``.

    Raises:
        Exception: Anything raised by ``worker`` is printed (message and
            traceback) and then re-raised.
    """
    dist = torch.distributed

    global_rank = node_rank * world_local_size + local_rank
    torch.cuda.set_device(local_rank)
    cuda_device = torch.device("cuda", local_rank)

    dist.init_process_group(
        backend="cpu:gloo,cuda:nccl",
        init_method=init_method,
        rank=global_rank,
        world_size=world_size,
        device_id=cuda_device,
    )

    # A one-element all-reduce serves as a start-up barrier across ranks.
    sync_token = torch.tensor([global_rank], device=cuda_device)
    dist.all_reduce(sync_token)

    try:
        pgi = ProcessGroupInfo(
            world_size=world_size,
            world_local_size=world_local_size,
            rank=global_rank,
            node_rank=node_rank,
            local_rank=local_rank,
            device=cuda_device,
        )
        worker(pgi, *args, **kwargs)
    except Exception as err:
        # Print from inside the child process before propagating.
        print(err)
        traceback.print_exc()
        raise
    finally:
        # Tear the group down whether or not the worker succeeded.
        dist.destroy_process_group()


def parallel_launch(
    world_size: int,
    worker: Callable[Concatenate[ProcessGroupInfo, P], None],
    *args: P.args,
    **kwargs: P.kwargs,
) -> None:
    """Spawn ``world_size`` single-node processes that each run ``worker``.

    Every process is started through ``_worker_parallel_launch`` with
    ``world_local_size == world_size`` and ``node_rank == 0``, and
    rendezvous over a TCP address built from the ``LOCALHOST`` environment
    variable (default ``"localhost"``) and a freshly chosen open port.
    The call blocks until all processes have finished (``join=True``).

    Args:
        world_size: Number of processes to spawn.
        worker: Callable invoked in each process with a ``ProcessGroupInfo``
            followed by ``*args`` and ``**kwargs``. The worker and all
            arguments are handed to ``spawn`` and so must be transferable
            to the child processes.
        *args: Positional arguments forwarded to ``worker``.
        **kwargs: Keyword arguments forwarded to ``worker``.
    """
    import functools

    # spawn() forwards positional arguments only, so keyword arguments are
    # bound onto the worker up front instead of being rejected. Without
    # keywords the worker is passed through untouched.
    target = functools.partial(worker, **kwargs) if kwargs else worker

    host = os.getenv("LOCALHOST", "localhost")
    init_method = f"tcp://{host}:{get_open_port()}"

    spawn(
        _worker_parallel_launch,
        args=(
            world_size,  # world_size
            world_size,  # world_local_size: single node
            0,  # node_rank
            init_method,
            target,
            *args,
        ),
        nprocs=world_size,
        join=True,
    )


## DeepEP specific utils


@dataclasses.dataclass
class DeepEPHTArgs:
    """Settings for the DeepEP high-throughput all-to-all.

    Attributes:
        num_local_experts: Number of experts per rank; multiplied by the
            rank to obtain that rank's expert offset.
    """

    num_local_experts: int


@dataclasses.dataclass
class DeepEPLLArgs:
    """Settings for the DeepEP low-latency all-to-all.

    Attributes:
        max_tokens_per_rank: Token budget per rank; used for the RDMA size
            hint and forwarded to the prepare/finalize object.
        hidden_size: Hidden dimension used for the RDMA size hint.
        num_experts: Total expert count; used for the RDMA size hint and,
            divided by the world size, as the per-rank QP count.
        use_fp8_dispatch: Forwarded to the prepare/finalize object.
    """

    # Field order is part of the positional constructor; keep it stable.
    max_tokens_per_rank: int
    hidden_size: int
    num_experts: int
    use_fp8_dispatch: bool


126
127
128
129
130
def make_deepep_ht_a2a(
    pg: ProcessGroup,
    pgi: ProcessGroupInfo,
    dp_size: int,
    ht_args: DeepEPHTArgs,
    q_dtype: torch.dtype | None = None,
    block_shape: list[int] | None = None,
):
    """Create a DeepEP high-throughput prepare/finalize object.

    Args:
        pg: Process group passed to ``deep_ep.Buffer``.
        pgi: Rank information; supplies the dispatcher count and the rank
            used for the expert offset.
        dp_size: Forwarded as ``dp_size``.
        ht_args: High-throughput settings.
        q_dtype: Unused by this function.
        block_shape: Unused by this function.

    Returns:
        A ``DeepEPHTPrepareAndFinalize`` wrapping a newly created
        ``deep_ep.Buffer``.
    """
    import deep_ep

    # High-throughput configuration: 1 GiB for num_nvl_bytes, no RDMA bytes,
    # low-latency mode off, one QP per rank.
    ht_buffer = deep_ep.Buffer(
        group=pg,
        num_nvl_bytes=1024 * 1024 * 1024,
        num_rdma_bytes=0,
        low_latency_mode=False,
        num_qps_per_rank=1,
    )

    # Each rank's experts start at rank * num_local_experts.
    expert_offset = pgi.rank * ht_args.num_local_experts
    return DeepEPHTPrepareAndFinalize(
        buffer=ht_buffer,
        num_dispatchers=pgi.world_size,
        dp_size=dp_size,
        rank_expert_offset=expert_offset,
    )

bnellnm's avatar
bnellnm committed
153

154
155
156
157
def make_deepep_ll_a2a(
    pg: ProcessGroup,
    pgi: ProcessGroupInfo,
    deepep_ll_args: DeepEPLLArgs,
    q_dtype: torch.dtype | None = None,
    block_shape: list[int] | None = None,
):
    """Create a DeepEP low-latency prepare/finalize object.

    Args:
        pg: Process group passed to ``deep_ep.Buffer``.
        pgi: Rank information; supplies the world size.
        deepep_ll_args: Low-latency settings.
        q_dtype: Unused by this function.
        block_shape: Unused by this function.

    Returns:
        A ``DeepEPLLPrepareAndFinalize`` wrapping a newly created
        ``deep_ep.Buffer`` in low-latency mode.
    """
    import deep_ep

    ll = deepep_ll_args
    num_ranks = pgi.world_size

    # Ask DeepEP how many RDMA bytes the low-latency buffer should reserve.
    rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
        ll.max_tokens_per_rank,
        ll.hidden_size,
        num_ranks,
        ll.num_experts,
    )

    ll_buffer = deep_ep.Buffer(
        group=pg,
        num_rdma_bytes=rdma_bytes,
        low_latency_mode=True,
        # QP count per rank is the (floor) number of experts per rank.
        num_qps_per_rank=ll.num_experts // num_ranks,
    )

    return DeepEPLLPrepareAndFinalize(
        buffer=ll_buffer,
        num_dispatchers=num_ranks,
        max_tokens_per_rank=ll.max_tokens_per_rank,
        use_fp8_dispatch=ll.use_fp8_dispatch,
    )


186
187
188
189
def make_deepep_a2a(
    pg: ProcessGroup,
    pgi: ProcessGroupInfo,
    dp_size: int,
    deepep_ht_args: DeepEPHTArgs | None,
    deepep_ll_args: DeepEPLLArgs | None,
    q_dtype: torch.dtype | None = None,
    block_shape: list[int] | None = None,
):
    """Create the high-throughput or low-latency DeepEP all-to-all object.

    Exactly one of ``deepep_ht_args`` and ``deepep_ll_args`` must be given;
    the one that is not ``None`` selects which factory is called.

    Args:
        pg: Process group forwarded to the selected factory.
        pgi: Rank information forwarded to the selected factory.
        dp_size: Forwarded to the high-throughput factory only.
        deepep_ht_args: High-throughput settings, or ``None``.
        deepep_ll_args: Low-latency settings, or ``None``.
        q_dtype: Forwarded to the selected factory.
        block_shape: Forwarded to the selected factory.

    Returns:
        Whatever the selected factory returns.

    Raises:
        AssertionError: If both or neither of the argument sets are given.
    """
    if deepep_ht_args is None:
        # Low-latency path: the LL settings are then mandatory.
        assert deepep_ll_args is not None
        return make_deepep_ll_a2a(pg, pgi, deepep_ll_args, q_dtype, block_shape)

    # High-throughput path: the two argument sets are mutually exclusive.
    assert deepep_ll_args is None
    return make_deepep_ht_a2a(
        pg, pgi, dp_size, deepep_ht_args, q_dtype, block_shape
    )