async_worker.py 5.78 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
The async worker that transfers experts in the background.
"""

import asyncio
import threading
from typing import TYPE_CHECKING

import torch
from torch.distributed import ProcessGroup

14
from vllm.distributed.parallel_state import get_eplb_group
15
16
from vllm.logger import init_logger

17
18
from .eplb_utils import CpuGpuEvent
from .rebalance_execute import AsyncEplbLayerResult, transfer_layer
19
20

if TYPE_CHECKING:
21
    from .eplb_state import EplbModelState, EplbState
22
23
24
25
26
27
28
29

logger = init_logger(__name__)


def start_async_worker(
    state: "EplbState",
    is_profile: bool = False,
) -> threading.Thread:
    """Start the background thread that runs the async EPLB transfer loop.

    The thread selects ``state.cuda_device_index`` as its device, creates a
    dedicated CUDA stream and a private asyncio event loop, and then drives
    ``transfer_run_periodically`` on that loop.

    Args:
        state: EPLB state shared with the worker. ``state.is_async`` must be
            truthy and ``state.cuda_device_index`` must not be ``None``.
        is_profile: Forwarded unchanged to ``transfer_run_periodically``.

    Returns:
        The daemon thread, already started.
    """
    eplb_group = get_eplb_group().device_group
    rank = eplb_group.rank()
    device_index = state.cuda_device_index
    assert state.is_async

    def thread_target() -> None:
        assert device_index is not None
        # Select the device from inside the new thread before creating the
        # stream that all of the worker's transfers will use.
        torch.accelerator.set_device_index(device_index)
        cuda_stream = torch.cuda.Stream(device=device_index)

        # The worker owns its own event loop; it is closed on every exit path.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            loop.run_until_complete(
                transfer_run_periodically(
                    state=state,
                    eplb_group=eplb_group,
                    cuda_stream=cuda_stream,
                    is_profile=is_profile,
                )
            )
        except Exception as exc:  # pragma: no cover - diagnostic path
            # Log with the rank so the failing worker can be identified.
            logger.exception("async loop error (Rank %d): %s", rank, str(exc))
        finally:
            loop.close()

    thread = threading.Thread(target=thread_target, daemon=True)
    thread.start()
    return thread


60
61
62
63
def run_rebalance_experts(
    model_state: "EplbModelState",
    eplb_state: "EplbState",
    physical_to_logical_map_cpu: torch.Tensor,
    cuda_stream: torch.cuda.Stream,
) -> torch.Tensor:
    """Compute a new physical-to-logical expert map for one model.

    Args:
        model_state: Per-model EPLB state; ``model_state.eplb_stats`` must not
            be ``None``.
        eplb_state: Global EPLB state providing ``policy.rebalance_experts``.
        physical_to_logical_map_cpu: The current mapping, passed to the policy
            as its last argument.
        cuda_stream: Stream made current while the load window is copied to
            the CPU.

    Returns:
        The mapping returned by the policy; asserted to live on the CPU.
    """
    assert model_state.eplb_stats is not None
    eplb_stats = model_state.eplb_stats

    # Move the global expert load window to CPU for computation.
    with torch.cuda.stream(cuda_stream):
        global_expert_load_window = eplb_stats.global_expert_load_window.cpu()

    # Compute new expert mappings for the model
    new_physical_to_logical_map = eplb_state.policy.rebalance_experts(
        global_expert_load_window,
        eplb_stats.num_replicas,
        eplb_stats.num_groups,
        eplb_stats.num_nodes,
        eplb_stats.num_gpus,
        physical_to_logical_map_cpu,
    )
    assert new_physical_to_logical_map.device == torch.device("cpu")

    return new_physical_to_logical_map
84
85


86
87
async def transfer_run_periodically(
    state: "EplbState",
    eplb_group: ProcessGroup,
    cuda_stream: torch.cuda.Stream,
    is_profile: bool = False,
) -> None:
    """Run the async EPLB transfer loop; this coroutine never returns.

    Each outer iteration blocks on ``state.rearrange_event``. For every model
    in ``state.model_states`` it then snapshots the current
    physical-to-logical map, computes a new one with
    ``run_rebalance_experts`` and transfers the layers one at a time,
    publishing each as ``model_state.pending_result`` and waiting for it to be
    consumed before moving on to the next layer.

    Args:
        state: Shared EPLB state; ``state.is_async`` must be truthy.
        eplb_group: Process group passed to ``transfer_layer`` as ``ep_group``.
        cuda_stream: Stream used for the event waits, the CPU copies, the
            communicator and the layer transfers.
        is_profile: Forwarded unchanged to ``transfer_layer``.
    """
    while True:
        state.rearrange_event.wait(stream=cuda_stream)
        logger.info("async worker woke up for EPLB transfer")

        assert state.is_async
        for model_state in state.model_states.values():
            layer_idx = 0

            # Set the async worker's CUDA stream on the communicator
            model_state.communicator.set_stream(cuda_stream)
            num_layers = model_state.model.num_moe_layers

            # Snapshot the physical_to_logical_map (synchronized with
            # rearrange_event) and copy it to CPU
            with torch.cuda.stream(cuda_stream):
                physical_to_logical_map_cpu = model_state.physical_to_logical_map.cpu()

            new_physical_to_logical_map = run_rebalance_experts(
                model_state, state, physical_to_logical_map_cpu, cuda_stream
            )
            logger.info(
                "Async worker computed new indices for model %s",
                model_state.model_name,
            )

            # Execute one EPLB layer transfer per model forward pass. Each iteration
            # of this loop will copy the new set of expert weights into
            # model_state.expert_buffer, which will be consumed by the main thread in
            # move_to_workspace
            while model_state.rebalanced and layer_idx < num_layers:
                transfer_metadata = await transfer_layer(
                    old_layer_indices=physical_to_logical_map_cpu[layer_idx],
                    new_layer_indices=new_physical_to_logical_map[layer_idx],
                    expert_weights=model_state.model.expert_weights[layer_idx],
                    expert_weights_buffer=model_state.expert_buffer,
                    communicator=model_state.communicator,
                    ep_group=eplb_group,
                    is_profile=is_profile,
                    cuda_stream=cuda_stream,
                )

                # Wait until all writes to expert_buffer have finished before making the
                # AsyncEplbLayerResult visible to the main thread.
                cuda_stream.synchronize()

                # This event guarantees that expert_buffer will not be overwritten by
                # subsequent iterations of this loop until the main thread has consumed
                # it. Record is called by the main thread after move_from_buffer().
                consumed_event = CpuGpuEvent()

                model_state.pending_result = AsyncEplbLayerResult(
                    layer_idx=layer_idx,
                    new_physical_to_logical_map=new_physical_to_logical_map[layer_idx],
                    transfer_metadata=transfer_metadata,
                    consumed_event=consumed_event,
                )

                # Block this thread until the main thread and main stream
                # finish copying model_state.expert_buffer into
                # model_state.model.expert_weights[layer_idx]
                consumed_event.wait(stream=cuda_stream)
                logger.debug("Layer %d transfer complete", layer_idx)
                # The result must have been cleared by the time the event
                # fires; anything else means the handoff went wrong.
                assert model_state.pending_result is None
                layer_idx += 1