eplb_utils.py 4.29 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Callable
from functools import wraps
from typing import Any

import torch
import torch.nn as nn

from vllm.distributed.eplb.eplb_state import EplbState
from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import is_mixture_of_experts

# Module-level logger created through vLLM's `init_logger`, keyed by this
# module's name; used below to report when EPLB is enabled for a model.
logger = init_logger(__name__)


def step_eplb_after(*, is_dummy: bool = False) -> Callable:
    """Build a decorator that steps EPLB once the wrapped method has returned.

    The wrapped method is expected to be bound to an object exposing an
    ``eplb`` attribute with a ``step(is_dummy=..., is_profile=...)`` method.

    Args:
        is_dummy: Forwarded to ``self.eplb.step``. The ``is_profile`` keyword
            argument of the wrapped call is only honoured when this is truthy;
            otherwise ``is_profile=False`` is always passed.

    Returns:
        A decorator producing a wrapper that returns the wrapped method's
        result unchanged.
    """

    def decorator(fn: Callable) -> Callable:
        @wraps(fn)
        def wrapper(self: Any, *args, **kwargs) -> Any:
            # Call the method first: if it raises, the exception propagates
            # and EPLB is not stepped.
            outcome = fn(self, *args, **kwargs)

            # `skip_eplb` is only looked up among keyword arguments.
            if not kwargs.get("skip_eplb", False):
                if is_dummy:
                    profiling = kwargs.get("is_profile", False)
                else:
                    profiling = False
                self.eplb.step(is_dummy=is_dummy, is_profile=profiling)

            return outcome

        return wrapper

    return decorator


class EPLBController:
    def __init__(self, parallel_config: Any, device: torch.device):
        self.parallel_config = parallel_config
        self.device = device
        self.state: EplbState | None = None
        self.suppressed = False
        self._has_registered_models = False

    def prepare_load(self) -> None:
        self.state = None
        self._has_registered_models = False
        if self.parallel_config.enable_eplb:
            self.state = EplbState(self.parallel_config, self.device)

    def maybe_register_speculator(
        self,
        speculator: Any | None,
        speculative_config: Any | None,
        load_dummy_weights: bool,
    ) -> bool:
        # if speculator is a moe model, add it to eplb
        if (
            speculator is None
            or not hasattr(speculator, "model")
            or not self.parallel_config.enable_eplb
            or load_dummy_weights
        ):
            return False

        draft_model = speculator.model
        if not is_mixture_of_experts(draft_model):
            return False

        assert not self.parallel_config.enable_elastic_ep, (
            "Elastic EP is not supported with draft model."
        )
        assert speculative_config is not None
        assert speculative_config.draft_model_config is not None
        assert self.state is not None
        self.state.add_model(
            draft_model,
            speculative_config.draft_model_config,
        )
        self._has_registered_models = True
        return True

    def maybe_register_model(
        self,
        model: nn.Module,
        model_config: Any,
        load_dummy_weights: bool,
    ) -> bool:
        if not self.parallel_config.enable_eplb or load_dummy_weights:
            return False

        if not is_mixture_of_experts(model):
            return False

95
        logger.info_once("EPLB is enabled for model %s.", model_config.model)
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
        assert self.state is not None
        self.state.add_model(model, model_config)
        self._has_registered_models = True
        return True

    def maybe_start_async_loop(self, eplb_models_added: bool) -> None:
        if eplb_models_added and self.state is not None and self.state.is_async:
            self.state.start_async_loop()

    def step(
        self,
        is_dummy: bool = False,
        is_profile: bool = False,
    ) -> None:
        if (
            not self.parallel_config.enable_eplb
            or self.suppressed
            or self.state is None
            or not self._has_registered_models
        ):
            return

        self.state.step(
            is_dummy,
            is_profile,
            log_stats=self.parallel_config.eplb_config.log_balancedness,
        )

    def setup_from_mapping(
        self,
        model: nn.Module,
        model_config: Any,
        expanded_physical_to_logical: torch.Tensor,
        old_num_physical_experts: int,
    ) -> None:
        assert is_mixture_of_experts(model)

        self.state = EplbState.from_mapping(
            model=model,
            model_config=model_config,
            device=self.device,
            parallel_config=self.parallel_config,
            expanded_physical_to_logical=expanded_physical_to_logical,
            num_valid_physical_experts=old_num_physical_experts,
        )
        self._has_registered_models = True