parallel_utils.py 4.97 KB
Newer Older
bnellnm's avatar
bnellnm committed
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
bnellnm's avatar
bnellnm committed
3
4
5
"""
DeepEP test utilities
"""
6

bnellnm's avatar
bnellnm committed
7
8
9
import dataclasses
import os
import traceback
10
11
from collections.abc import Callable
from typing import Concatenate
bnellnm's avatar
bnellnm committed
12
13
14

import torch
from torch.distributed import ProcessGroup
15
from torch.multiprocessing import spawn  # pyright: ignore[reportPrivateImportUsage]
16
from typing_extensions import ParamSpec
bnellnm's avatar
bnellnm committed
17

18
from vllm.utils import get_open_port, has_deep_ep
bnellnm's avatar
bnellnm committed
19

20
if has_deep_ep():
21
    from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import (
22
23
        DeepEPHTPrepareAndFinalize,
    )
24
    from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import (
25
26
        DeepEPLLPrepareAndFinalize,
    )
bnellnm's avatar
bnellnm committed
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101

## Parallel Processes Utils

# Parameter specification for the extra arguments a worker accepts after its
# leading ProcessGroupInfo. Used with Concatenate[...] so that the *args /
# **kwargs forwarded by the launch helpers are typed against the worker's
# own signature.
P = ParamSpec("P")


@dataclasses.dataclass()
class ProcessGroupInfo:
    """Placement of a single rank within a launched process group.

    Instances are built by ``_worker_parallel_launch`` and passed to the
    user-supplied worker as its first positional argument.
    """

    # Total number of ranks in the process group.
    world_size: int
    # Ranks per node (global rank = node_rank * world_local_size + local_rank).
    world_local_size: int
    # Global rank of this process.
    rank: int
    # Index of the node this rank belongs to.
    node_rank: int
    # Rank within the node; also used as the CUDA device index.
    local_rank: int
    # Device this rank was bound to.
    device: torch.device


def _worker_parallel_launch(
    local_rank: int,
    world_size: int,
    world_local_size: int,
    node_rank: int,
    init_method: str,
    worker: Callable[Concatenate[ProcessGroupInfo, P], None],
    *args: P.args,
    **kwargs: P.kwargs,
) -> None:
    """Entry point executed in each spawned process.

    Binds the process to CUDA device ``local_rank`` and joins the global
    process group (gloo for CPU tensors, NCCL for CUDA tensors). It then
    runs one all-reduce so every rank has joined before work starts, and
    calls ``worker(ProcessGroupInfo(...), *args, **kwargs)``.

    Once the startup all-reduce has succeeded, the process group is
    destroyed on exit whether or not ``worker`` raised. Any exception from
    the worker is printed and then re-raised.
    """
    global_rank = local_rank + node_rank * world_local_size
    torch.cuda.set_device(local_rank)
    cuda_dev = torch.device("cuda", local_rank)

    torch.distributed.init_process_group(
        backend="cpu:gloo,cuda:nccl",
        init_method=init_method,
        rank=global_rank,
        world_size=world_size,
        device_id=cuda_dev,
    )

    # The collective cannot complete until every rank participates, so it
    # doubles as a startup barrier. Keep the tensor bound to a name for the
    # rest of the function, as the original did.
    sync_buf = torch.tensor([global_rank], device=cuda_dev)
    torch.distributed.all_reduce(sync_buf)

    try:
        info = ProcessGroupInfo(
            world_size=world_size,
            world_local_size=world_local_size,
            rank=global_rank,
            node_rank=node_rank,
            local_rank=local_rank,
            device=cuda_dev,
        )
        worker(info, *args, **kwargs)
    except Exception as err:
        # Print in the child so the failure is visible before it propagates.
        print(err)
        traceback.print_exc()
        raise
    finally:
        torch.distributed.destroy_process_group()


def parallel_launch(
    world_size: int,
    worker: Callable[Concatenate[ProcessGroupInfo, P], None],
    *args: P.args,
    **kwargs: P.kwargs,
) -> None:
    """Run ``worker`` in ``world_size`` processes on this single node.

    Each process executes ``_worker_parallel_launch`` with
    ``world_local_size == world_size`` and ``node_rank == 0``. Rendezvous
    uses TCP on the host named by the ``LOCALHOST`` environment variable
    (default ``"localhost"``) and a free port from ``get_open_port()``.
    The call blocks until all processes finish.

    Args:
        world_size: Number of processes (ranks) to launch.
        worker: Callable invoked as ``worker(pgi, *args, **kwargs)`` in each
            process, where ``pgi`` is that rank's ``ProcessGroupInfo``.
        *args: Extra positional arguments forwarded to ``worker``.
        **kwargs: Extra keyword arguments forwarded to ``worker``.
            ``torch.multiprocessing.spawn`` hands these to child processes,
            so they (like ``args``) must be picklable.
    """
    import functools

    # torch.multiprocessing.spawn forwards only positional arguments to its
    # target. Previously kwargs were rejected with an ``assert``, and under
    # ``python -O`` that assert vanished so kwargs were silently dropped.
    # Binding them onto the worker keeps them intact. A functools.partial of
    # a picklable callable is itself picklable, which the spawn start method
    # needs.
    target = functools.partial(worker, **kwargs) if kwargs else worker

    spawn(
        _worker_parallel_launch,
        args=(
            world_size,  # world_size
            world_size,  # world_local_size: single node, every rank is local
            0,  # node_rank
            f"tcp://{os.getenv('LOCALHOST', 'localhost')}:{get_open_port()}",
            target,
        )
        + args,
        nprocs=world_size,
        join=True,
    )


## DeepEP specific utils


@dataclasses.dataclass()
class DeepEPHTArgs:
    """Configuration for building a high-throughput DeepEP all-to-all."""

    # Experts owned by each rank; multiplied by the rank to obtain that
    # rank's expert offset.
    num_local_experts: int


@dataclasses.dataclass()
class DeepEPLLArgs:
    """Configuration for building a low-latency DeepEP all-to-all."""

    # Per-rank token capacity; feeds the RDMA size hint and the
    # prepare/finalize object.
    max_tokens_per_rank: int
    # Hidden dimension; feeds the RDMA size hint.
    hidden_size: int
    # Total expert count; feeds the RDMA size hint and the per-rank QP count.
    num_experts: int
    # Passed straight through to the low-latency prepare/finalize object.
    use_fp8_dispatch: bool


125
126
127
128
129
def make_deepep_ht_a2a(
    pg: ProcessGroup,
    pgi: ProcessGroupInfo,
    dp_size: int,
    ht_args: DeepEPHTArgs,
    q_dtype: torch.dtype | None = None,
    block_shape: list[int] | None = None,
):
    """Create a high-throughput DeepEP prepare/finalize object for this rank.

    Args:
        pg: Process group the DeepEP buffer communicates over.
        pgi: This rank's placement; ``world_size`` sets the dispatcher count
            and ``rank`` determines the expert offset.
        dp_size: Forwarded to ``DeepEPHTPrepareAndFinalize``.
        ht_args: Supplies ``num_local_experts`` for the expert offset.
        q_dtype: Accepted but unused by this builder.
        block_shape: Accepted but unused by this builder.
    """
    import deep_ep

    # High-throughput configuration: a 1 GiB NVL buffer, no RDMA buffer,
    # low-latency mode off, and a single QP per rank.
    ht_buffer = deep_ep.Buffer(
        group=pg,
        num_nvl_bytes=1 << 30,
        num_rdma_bytes=0,
        low_latency_mode=False,
        num_qps_per_rank=1,
    )

    expert_offset = pgi.rank * ht_args.num_local_experts
    return DeepEPHTPrepareAndFinalize(
        buffer=ht_buffer,
        num_dispatchers=pgi.world_size,
        dp_size=dp_size,
        rank_expert_offset=expert_offset,
    )

bnellnm's avatar
bnellnm committed
152

153
154
155
156
def make_deepep_ll_a2a(
    pg: ProcessGroup,
    pgi: ProcessGroupInfo,
    deepep_ll_args: DeepEPLLArgs,
    q_dtype: torch.dtype | None = None,
    block_shape: list[int] | None = None,
):
    """Create a low-latency DeepEP prepare/finalize object for this rank.

    Args:
        pg: Process group the DeepEP buffer communicates over.
        pgi: This rank's placement; ``world_size`` is used for sizing and as
            the dispatcher count.
        deepep_ll_args: Token capacity, hidden size, expert count and FP8
            dispatch flag.
        q_dtype: Accepted but unused by this builder.
        block_shape: Accepted but unused by this builder.
    """
    import deep_ep

    cfg = deepep_ll_args
    n_ranks = pgi.world_size

    # Let DeepEP compute the RDMA buffer size needed in low-latency mode.
    rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
        cfg.max_tokens_per_rank,
        cfg.hidden_size,
        n_ranks,
        cfg.num_experts,
    )

    ll_buffer = deep_ep.Buffer(
        group=pg,
        num_rdma_bytes=rdma_bytes,
        low_latency_mode=True,
        # num_experts // world_size, i.e. the per-rank expert count when the
        # experts divide evenly across ranks. NOTE(review): uneven splits
        # are not handled here; confirm callers guarantee divisibility.
        num_qps_per_rank=cfg.num_experts // n_ranks,
    )

    return DeepEPLLPrepareAndFinalize(
        buffer=ll_buffer,
        num_dispatchers=n_ranks,
        max_tokens_per_rank=cfg.max_tokens_per_rank,
        use_fp8_dispatch=cfg.use_fp8_dispatch,
    )


185
186
187
188
def make_deepep_a2a(
    pg: ProcessGroup,
    pgi: ProcessGroupInfo,
    dp_size: int,
    deepep_ht_args: DeepEPHTArgs | None,
    deepep_ll_args: DeepEPLLArgs | None,
    q_dtype: torch.dtype | None = None,
    block_shape: list[int] | None = None,
):
    """Build a DeepEP all-to-all in whichever mode was configured.

    Exactly one of ``deepep_ht_args`` / ``deepep_ll_args`` must be given.
    High-throughput args select ``make_deepep_ht_a2a``; otherwise the
    low-latency args select ``make_deepep_ll_a2a``. An ``AssertionError``
    is raised if both or neither are supplied. ``dp_size`` is only used by
    the high-throughput path.
    """
    if deepep_ht_args is None:
        assert deepep_ll_args is not None
        return make_deepep_ll_a2a(pg, pgi, deepep_ll_args, q_dtype, block_shape)

    assert deepep_ll_args is None
    return make_deepep_ht_a2a(
        pg, pgi, dp_size, deepep_ht_args, q_dtype, block_shape
    )