# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
The async worker that transfers experts in the background.
"""

import asyncio
import threading
from typing import TYPE_CHECKING

import torch
from torch.distributed import ProcessGroup

from vllm.distributed.parallel_state import get_eplb_group
from vllm.logger import init_logger

from .rebalance_execute import transfer_layer

if TYPE_CHECKING:
    from .eplb_state import EplbModelState, EplbState

logger = init_logger(__name__)


def start_async_worker(
    state: "EplbState",
    is_profile: bool = False,
) -> threading.Thread:
    """Launch the daemon thread that runs the background expert transfer loop.

    The EPLB device group, this rank's index within it, and the CUDA device
    index are captured on the calling thread; everything else happens inside
    the new thread.

    Args:
        state: EPLB state shared with the worker. ``state.is_async`` must be
            true and ``state.cuda_device_index`` must not be ``None`` by the
            time the thread body runs.
        is_profile: Forwarded unchanged to ``transfer_run_periodically``.

    Returns:
        The already-started daemon ``threading.Thread``.

    Raises:
        AssertionError: If ``state.is_async`` is false.
    """
    eplb_group = get_eplb_group().device_group
    rank = eplb_group.rank()
    device_index = state.cuda_device_index
    assert state.is_async

    def thread_target() -> None:
        # Bind this thread to the worker's device and give it its own CUDA
        # stream and its own asyncio event loop.
        assert device_index is not None
        torch.cuda.set_device(device_index)
        cuda_stream = torch.cuda.Stream(device=device_index)
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            loop.run_until_complete(
                transfer_run_periodically(
                    state=state,
                    eplb_group=eplb_group,
                    cuda_stream=cuda_stream,
                    is_profile=is_profile,
                )
            )
        except Exception as exc:  # pragma: no cover - diagnostic path
            # Failures are logged (with traceback) and not re-raised; the
            # thread then ends after the loop is closed.
            logger.exception("async loop error (Rank %d): %s", rank, str(exc))
        finally:
            loop.close()

    thread = threading.Thread(target=thread_target, daemon=True)
    thread.start()
    return thread


def run_rebalance_experts(
    model_state: "EplbModelState",
    eplb_state: "EplbState",
    physical_to_logical_map_cpu: torch.Tensor,
) -> None:
    """Run the rebalancing policy and stage its result on ``model_state``.

    Blocks on ``model_state.window_ready_event`` (then resets it to ``None``),
    hands a CPU copy of the global expert load window to
    ``eplb_state.policy.rebalance_experts`` and stores the three returned
    tensors as ``new_physical_to_logical_map``, ``new_logical_to_physical_map``
    and ``new_logical_replica_count``.

    Args:
        model_state: Per-model EPLB state; read for its stats, ready event and
            current maps, and written with the ``new_*`` results.
        eplb_state: Global EPLB state; only ``policy`` is used here.
        physical_to_logical_map_cpu: Current physical-to-logical map, passed
            as the last positional argument to the policy.

    Raises:
        AssertionError: If ``eplb_stats`` or ``window_ready_event`` is
            ``None``, or if the policy's physical-to-logical map is not on
            the CPU device.
    """
    assert model_state.eplb_stats is not None
    stats = model_state.eplb_stats

    # The load window must not be read before the main thread's all-reduce
    # and clone have finished; the event signals that.
    assert model_state.window_ready_event is not None
    model_state.window_ready_event.wait()
    model_state.window_ready_event = None

    # The policy is given a CPU copy of the load window.
    load_window_cpu = stats.global_expert_load_window.cpu()
    placement = eplb_state.policy.rebalance_experts(
        load_window_cpu,
        stats.num_replicas,
        stats.num_groups,
        stats.num_nodes,
        stats.num_gpus,
        physical_to_logical_map_cpu,
    )
    phys_to_logical, logical_to_phys, replica_count = placement
    assert phys_to_logical.device == torch.device("cpu")

    model_state.new_physical_to_logical_map = phys_to_logical

    # Right-pad the last dimension with -1 up to the slot count of the
    # current logical-to-physical map (no padding if it is already as wide),
    # then move both results to the devices of the tensors they will replace.
    slot_count = model_state.logical_to_physical_map.shape[-1]
    missing_slots = max(0, slot_count - logical_to_phys.shape[-1])
    padded_map = torch.nn.functional.pad(
        logical_to_phys, (0, missing_slots), value=-1
    ).to(model_state.logical_to_physical_map.device)
    moved_count = replica_count.to(model_state.logical_replica_count.device)

    model_state.new_logical_to_physical_map = padded_map
    model_state.new_logical_replica_count = moved_count


async def transfer_run_periodically(
    state: "EplbState",
    eplb_group: ProcessGroup,
    cuda_stream: torch.cuda.Stream,
    is_profile: bool = False,
) -> None:
    """Serve EPLB transfer requests on the async worker's event loop.

    Each pass waits for ``state.rearrange_event`` (blocking in a helper
    thread via ``asyncio.to_thread``), then visits every entry of
    ``state.model_states``. While a model is flagged ``rebalanced`` and
    ``layer_to_transfer`` is below its number of MoE layers, one layer at a
    time is handed to ``transfer_layer`` under ``buffer_lock``. After all
    models have been visited the event is cleared and the coroutine waits
    again.

    Args:
        state: Shared EPLB state; ``state.is_async`` is asserted on every
            wake-up.
        eplb_group: Process group passed to ``transfer_layer`` as
            ``ep_group``.
        cuda_stream: Stream passed to ``transfer_layer``; also used to wait
            on ``buffer_consumed_event`` and to record ``buffer_ready_event``.
        is_profile: Forwarded unchanged to ``transfer_layer``.

    The outer loop has no exit, so this coroutine only ends if an exception
    propagates out of it.
    """
    while True:
        await asyncio.to_thread(state.rearrange_event.wait)
        logger.info("async worker woke up for EPLB transfer")

        assert state.is_async
        for model_state in state.model_states.values():
            # New mappings are computed once per wake-up for each model
            # (again only if no new physical-to-logical map is staged).
            rebalancing_algorithm_executed = False
            physical_to_logical_map_cpu = None
            current_num_layers = model_state.model.num_moe_layers
            while (
                model_state.rebalanced
                and model_state.layer_to_transfer < current_num_layers
            ):
                if not model_state.ep_buffer_ready and model_state.rebalanced:
                    # Polling the lock directly in the async thread avoids
                    # the thread switch overhead of asyncio.to_thread.
                    # This is typically faster than offloading to a worker thread.
                    while not model_state.buffer_lock.acquire(blocking=False):
                        await asyncio.sleep(0)
                    try:
                        # Re-check under the lock: the layer counter may have
                        # moved while the lock was being acquired. The
                        # ``finally`` below still releases the lock.
                        if model_state.layer_to_transfer >= current_num_layers:
                            break
                        if (
                            not rebalancing_algorithm_executed
                            or model_state.new_physical_to_logical_map is None
                        ):
                            # Move the physical_to_logical_map to CPU
                            # for rebalancing and transfer_layer.
                            physical_to_logical_map_cpu = (
                                model_state.physical_to_logical_map.cpu()
                            )
                            run_rebalance_experts(
                                model_state, state, physical_to_logical_map_cpu
                            )
                            rebalancing_algorithm_executed = True
                            logger.info(
                                "Async worker computed new indices for model %s",
                                model_state.model_name,
                            )

                        assert model_state.new_physical_to_logical_map is not None
                        assert physical_to_logical_map_cpu is not None

                        layer_idx = model_state.layer_to_transfer
                        old_layer_indices = physical_to_logical_map_cpu[layer_idx]
                        new_layer_indices = model_state.new_physical_to_logical_map[
                            layer_idx
                        ]

                        # Wait for the main thread to finish consuming the buffer
                        # before initiating an EPLB transfer on another layer.
                        if model_state.buffer_consumed_event is not None:
                            cuda_stream.wait_event(model_state.buffer_consumed_event)
                            model_state.buffer_consumed_event = None

                        (
                            model_state.is_unchanged,
                            model_state.is_received_locally,
                            model_state.recv_metadata,
                        ) = await transfer_layer(
                            old_layer_indices=old_layer_indices,
                            new_layer_indices=new_layer_indices,
                            expert_weights=model_state.model.expert_weights[layer_idx],
                            expert_weights_buffer=model_state.expert_buffer,
                            ep_group=eplb_group,
                            is_profile=is_profile,
                            cuda_stream=cuda_stream,
                        )
                        # Record an event on the transfer stream and publish
                        # it together with the ready flag.
                        # NOTE(review): neither ep_buffer_ready nor
                        # layer_to_transfer is reset/advanced in this
                        # function; presumably the consumer on the main
                        # thread does that — confirm.
                        event = torch.cuda.Event(blocking=False)
                        cuda_stream.record_event(event)
                        model_state.buffer_ready_event = event
                        model_state.ep_buffer_ready = 1
                    finally:
                        model_state.buffer_lock.release()
                else:
                    # Either rebalancing was switched off (stop working on
                    # this model) or the buffer is still marked ready; in the
                    # latter case back off for 1 ms before checking again.
                    if not model_state.rebalanced:
                        break
                    await asyncio.sleep(0.001)

        state.rearrange_event.clear()