async_worker.py 6.88 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
The async worker that transfers experts in the background.
"""

import asyncio
import threading
from typing import TYPE_CHECKING

import torch
from torch.distributed import ProcessGroup

14
from vllm.distributed.parallel_state import get_eplb_group
15
16
17
18
19
from vllm.logger import init_logger

from .rebalance_execute import transfer_layer

if TYPE_CHECKING:
20
    from .eplb_state import EplbModelState, EplbState
21
22
23
24
25
26
27
28

# Module-level logger obtained from vLLM's logger factory, keyed by this module's name.
logger = init_logger(__name__)


def start_async_worker(
    state: "EplbState",
    is_profile: bool = False,
) -> threading.Thread:
    """Start the background thread that performs EPLB expert transfers.

    The EPLB device group, this process's rank in it, and
    ``state.cuda_device_index`` are resolved on the calling thread. The
    spawned daemon thread then:

    1. selects ``state.cuda_device_index`` as its current accelerator device,
    2. creates a dedicated CUDA stream on that device,
    3. creates and installs a private asyncio event loop, and
    4. runs :func:`transfer_run_periodically` on that loop.

    Any exception escaping the coroutine is logged together with the rank
    instead of being propagated. The event loop is always closed when the
    thread exits.

    Args:
        state: EPLB state shared with the main thread. Must have
            ``is_async`` set and a non-``None`` ``cuda_device_index``
            (the latter is asserted inside the thread).
        is_profile: Forwarded unchanged to :func:`transfer_run_periodically`.

    Returns:
        The started daemon thread.
    """
    eplb_group = get_eplb_group().device_group
    rank = eplb_group.rank()
    device_index = state.cuda_device_index
    assert state.is_async

    def thread_target() -> None:
        assert device_index is not None
        # Make the target device current for this thread before creating the
        # stream used for all background transfers.
        torch.accelerator.set_device_index(device_index)
        cuda_stream = torch.cuda.Stream(device=device_index)

        # This thread drives its own event loop, independent of any loop the
        # main thread may be running.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            loop.run_until_complete(
                transfer_run_periodically(
                    state=state,
                    eplb_group=eplb_group,
                    cuda_stream=cuda_stream,
                    is_profile=is_profile,
                )
            )
        except Exception as exc:  # pragma: no cover - diagnostic path
            logger.exception("async loop error (Rank %d): %s", rank, str(exc))
        finally:
            loop.close()

    thread = threading.Thread(target=thread_target, daemon=True)
    thread.start()
    return thread


59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
def run_rebalance_experts(
    model_state: "EplbModelState",
    eplb_state: "EplbState",
    physical_to_logical_map_cpu: torch.Tensor,
) -> None:
    """Compute a new physical-to-logical expert mapping for one model.

    Waits on ``model_state.window_ready_event``, then drops the reference to
    it. Next it copies ``eplb_stats.global_expert_load_window`` to CPU and
    hands it, together with the replica/group/node/GPU counts and the current
    mapping, to ``eplb_state.policy.rebalance_experts``. The result must be a
    CPU tensor and is stored in ``model_state.new_physical_to_logical_map``.

    Args:
        model_state: Per-model EPLB state. ``eplb_stats`` and
            ``window_ready_event`` must both be set.
        eplb_state: EPLB state that owns the rebalancing ``policy``.
        physical_to_logical_map_cpu: The current mapping, already on CPU.
            It is passed through to the policy.
    """
    assert model_state.eplb_stats is not None
    eplb_stats = model_state.eplb_stats

    # Wait for the main thread's all-reduce and clone to complete before
    # accessing the global_expert_load_window tensor.
    assert model_state.window_ready_event is not None
    model_state.window_ready_event.wait()
    # Consume the event: a later call without a fresh event will trip the
    # assertion above.
    model_state.window_ready_event = None

    # Move the global expert load window to CPU for computation.
    global_expert_load_window = eplb_stats.global_expert_load_window.cpu()

    # Compute new expert mappings for the model.
    new_physical_to_logical_map = eplb_state.policy.rebalance_experts(
        global_expert_load_window,
        eplb_stats.num_replicas,
        eplb_stats.num_groups,
        eplb_stats.num_nodes,
        eplb_stats.num_gpus,
        physical_to_logical_map_cpu,
    )
    assert new_physical_to_logical_map.device == torch.device("cpu")

    model_state.new_physical_to_logical_map = new_physical_to_logical_map


89
90
async def transfer_run_periodically(
    state: "EplbState",
    eplb_group: ProcessGroup,
    cuda_stream: torch.cuda.Stream,
    is_profile: bool = False,
) -> None:
    """Serve EPLB rearrangement requests forever, one MoE layer at a time.

    Each iteration blocks (in a helper thread) until ``state.rearrange_event``
    is set. It then visits every model state. While the model is marked
    ``rebalanced`` and ``layer_to_transfer`` is below ``num_moe_layers``:

    * If the exchange buffer is not ready, the worker:

      - takes ``buffer_lock``;
      - lazily computes the new mapping via :func:`run_rebalance_experts`;
      - makes ``cuda_stream`` wait on any pending ``buffer_consumed_event``;
      - transfers layer ``layer_to_transfer`` into ``expert_buffer`` with
        ``transfer_layer``;
      - synchronizes the stream;
      - sets ``ep_buffer_ready``.

    * Otherwise it sleeps 1 ms and re-checks.

    ``layer_to_transfer`` and ``ep_buffer_ready`` are never advanced or
    reset here, so presumably the main thread does that after consuming the
    buffer (verify against the caller). After all models are visited,
    ``state.rearrange_event`` is cleared.

    Args:
        state: Shared EPLB state (must have ``is_async`` set).
        eplb_group: Process group used for the expert transfer.
        cuda_stream: Stream on which transfers are issued and synchronized.
        is_profile: Forwarded to ``transfer_layer``.
    """
    while True:
        # Wait in a helper thread so the blocking wait does not stall the loop.
        await asyncio.to_thread(state.rearrange_event.wait)
        logger.info("async worker woke up for EPLB transfer")

        assert state.is_async
        for model_state in state.model_states.values():
            # The rebalancing policy runs at most once per model per wake-up,
            # unless new_physical_to_logical_map has been cleared in between.
            rebalancing_algorithm_executed = False
            physical_to_logical_map_cpu = None

            current_num_layers = model_state.model.num_moe_layers
            while (
                model_state.rebalanced
                and model_state.layer_to_transfer < current_num_layers
            ):
                if not model_state.ep_buffer_ready and model_state.rebalanced:
                    # Polling the lock directly in the async thread avoids
                    # the thread switch overhead of asyncio.to_thread.
                    # This is typically faster than offloading to a worker thread.
                    while not model_state.buffer_lock.acquire(blocking=False):
                        await asyncio.sleep(0)

                    try:
                        # Re-check under the lock: layer_to_transfer is not
                        # advanced in this function, so it may have moved
                        # since the loop condition was evaluated.
                        if model_state.layer_to_transfer >= current_num_layers:
                            break

                        if (
                            not rebalancing_algorithm_executed
                            or model_state.new_physical_to_logical_map is None
                        ):
                            # Move the physical_to_logical_map to CPU
                            # for rebalancing and transfer_layer.
                            physical_to_logical_map_cpu = (
                                model_state.physical_to_logical_map.cpu()
                            )
                            run_rebalance_experts(
                                model_state, state, physical_to_logical_map_cpu
                            )
                            rebalancing_algorithm_executed = True
                            logger.info(
                                "Async worker computed new indices for model %s",
                                model_state.model_name,
                            )

                        assert model_state.new_physical_to_logical_map is not None
                        assert physical_to_logical_map_cpu is not None

                        layer_idx = model_state.layer_to_transfer
                        old_layer_indices = physical_to_logical_map_cpu[layer_idx]
                        new_layer_indices = model_state.new_physical_to_logical_map[
                            layer_idx
                        ]

                        # Wait for the main thread to finish consuming the buffer
                        # before initiating an EPLB transfer on another layer.
                        if model_state.buffer_consumed_event is not None:
                            cuda_stream.wait_event(model_state.buffer_consumed_event)
                            model_state.buffer_consumed_event = None

                        (
                            model_state.is_unchanged,
                            model_state.is_received_locally,
                            model_state.recv_metadata,
                        ) = await transfer_layer(
                            old_layer_indices=old_layer_indices,
                            new_layer_indices=new_layer_indices,
                            expert_weights=model_state.model.expert_weights[layer_idx],
                            expert_weights_buffer=model_state.expert_buffer,
                            ep_group=eplb_group,
                            is_profile=is_profile,
                            cuda_stream=cuda_stream,
                        )

                        # Block the async thread until the transfer to
                        # the intermediate buffer is complete.
                        cuda_stream.synchronize()

                        model_state.ep_buffer_ready = 1
                    finally:
                        # Released on every path, including the early break
                        # and any exception from the transfer.
                        model_state.buffer_lock.release()
                else:
                    if not model_state.rebalanced:
                        break
                    await asyncio.sleep(0.001)

        # NOTE(review): the event is cleared only after the whole pass, so a
        # set() issued while transfers are in flight does not trigger another
        # wake-up. Confirm the main thread never re-arms mid-transfer.
        state.rearrange_event.clear()