# medusa.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import os
5
from collections.abc import Iterable
6
7
8
9

import torch
import torch.nn as nn

10
from vllm.config import VllmConfig
11
12
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
13
14
    ParallelLMHead,
)
15
16
from vllm.model_executor.model_loader.weight_utils import default_weight_loader

17
18
from .utils import maybe_prefix

19
20
from vllm import _custom_ops as ops

21
22

class ResidualBlock(nn.Module):
    """Stack of residual ``x + SiLU(Linear(x))`` layers used by one Medusa head.

    Every layer maps ``hidden_size -> hidden_size``, so the residual sum keeps
    the last dimension of the input unchanged.
    """

    def __init__(self, config: object, hidden_size: int, num_layers: int) -> None:
        """Build ``num_layers`` square linear layers.

        Args:
            config: The draft model's HF config. ``Medusa`` passes
                ``draft_model_config.hf_config`` here, not a ``VllmConfig``.
                Only the optional ``medusa_fc_bias`` attribute is read. When
                that attribute is absent, the layers are built without bias.
            hidden_size: Input and output width of every linear layer.
            num_layers: Number of residual layers. With 0 the block is an
                identity.
        """
        super().__init__()
        # The bias switch is the same for every layer, so read it once.
        use_bias = getattr(config, "medusa_fc_bias", False)
        self.layers = nn.ModuleList(
            [
                nn.Linear(hidden_size, hidden_size, bias=use_bias)
                for _ in range(num_layers)
            ]
        )
        self.act = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply each layer residually: ``x <- x + SiLU(layer(x))``."""
        for layer in self.layers:
            x = x + self.act(layer(x))
        return x


class Medusa(nn.Module):
    """This class implements the Medusa draft model from the paper: https://arxiv.org/abs/2401.10774
    Reference implementation: https://github.com/FasterDecoding/Medusa

    Differences from reference implementation:
    1. Currently this only supports generating proposals from top-1 tokens.
    2. We have an optional token_map which reduces draft vocab to most
       frequently used tokens to give some additional speed-up by reducing
       sampling overhead. This is disabled unless the checkpoint file has
       explicit token_map tensor and config has an optional attribute
       truncated_vocab_size < vocab_size. To use this technique, one has to find
       the top-k most frequent tokens in target dataset and add that as a tensor
       in the draft checkpoint (using key token_map). Also, the draft config
       needs to have truncated_vocab_size (=k) as an attribute.

    Environment switches:
    - Two variables control an extra LM-head step: ``LLAMA_NN``, read in
      ``__init__``, and ``LM_NN``, read in ``load_weights``. When both equal
      ``"1"``, every loaded parameter whose name contains ``"lm_head"`` is
      re-laid-out with the custom ``ops.trans_w16_gemm`` kernel and then
      reshaped to ``(original dim 1, -1)``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        config = vllm_config.speculative_config.draft_model_config.hf_config
        super().__init__()

        # Fork-specific switch for the LM-head re-layout (see class docstring).
        self.use_llama_nn = os.environ.get("LLAMA_NN") == "1"
        self.config = config

        # One residual MLP stack per Medusa head.
        self.blocks = nn.ModuleList(
            [
                ResidualBlock(
                    config=config,
                    hidden_size=self.config.hidden_size,
                    num_layers=self.config.num_hidden_layers,
                )
                for _ in range(self.config.num_heads)
            ]
        )
        self.orig_vocab_size = config.vocab_size
        self.truncated_vocab_size = config.truncated_vocab_size

        if getattr(config, "original_lm_head", False):
            # A single LM head over the truncated vocab is shared by every
            # Medusa head. The plain list holds the same module num_heads
            # times; it is registered once, as ``lm_head``.
            self.lm_head = ParallelLMHead(
                self.truncated_vocab_size,
                config.hidden_size,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
            self.lm_heads = [self.lm_head for _ in range(self.config.num_heads)]
        else:
            # One independent LM head per Medusa head.
            self.lm_heads = nn.ModuleList(
                [
                    ParallelLMHead(
                        config.vocab_size,
                        config.hidden_size,
                        prefix=maybe_prefix(prefix, f"lm_heads.{i}"),
                    )
                    for i in range(self.config.num_heads)
                ]
            )

        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(
            config.vocab_size, self.truncated_vocab_size, logit_scale
        )

        # Token map is an idx-to-token mapping that reduces the vocab size of
        # the draft model. Using a smaller draft vocab that contains only the
        # most frequent tokens reduces the speculation overhead. This doesn't
        # affect the acceptance rate much, so it gives more speed-up. By
        # default it is disabled. It is only used if the Medusa checkpoint has
        # a token_map tensor and truncated_vocab_size < vocab_size.
        self.token_map = None

    def forward(self, hidden_states: torch.Tensor) -> list[torch.Tensor]:
        """Run ``hidden_states`` through every head's ResidualBlock.

        Returns one tensor per Medusa head, in head order.
        """
        return [block(hidden_states) for block in self.blocks]

    def compute_logits(
        self,
        hidden_states: list[torch.Tensor],
    ) -> list[torch.Tensor]:
        """Project each head's hidden states to logits with that head's LM head.

        Args:
            hidden_states: Per-head outputs, as returned by ``forward``.

        Returns:
            One logits tensor per head. When a token map is loaded, the
            truncated-vocab logits are scattered back to ``orig_vocab_size``.
            Unmapped token ids are filled with ``-inf``. The list is empty
            wherever ``logits_processor`` returns ``None``.
        """
        logits_lst: list[torch.Tensor] = []

        for hs, lm_head in zip(hidden_states, self.lm_heads):
            _logits = self.logits_processor(lm_head, hs)

            if _logits is None:
                # _logits should only be None on rank > 0, in which case
                # it should remain true for every lm_head
                assert len(logits_lst) == 0
                continue

            if self.token_map is None:
                logits_lst.append(_logits)
                continue

            # Expand the truncated-vocab logits to the full vocab. Tokens that
            # are not in token_map keep -inf, so softmax gives them zero
            # probability.
            full_logits = torch.full(
                (*_logits.shape[:-1], self.orig_vocab_size),
                -torch.inf,
                device=_logits.device,
                dtype=_logits.dtype,
            )
            full_logits[..., self.token_map] = _logits
            logits_lst.append(full_logits)

        return logits_lst

    @staticmethod
    def _relayout_lm_head_weight(param: nn.Parameter) -> None:
        """Re-lay-out a loaded LM-head weight in place for the LLAMA_NN path.

        ``ops.trans_w16_gemm`` is called with a zeroed scratch buffer and the
        parameter data. It presumably writes a transposed/kernel-specific
        layout into the scratch buffer — TODO confirm against the kernel. The
        scratch buffer is then copied back into ``param``. Finally ``param`` is
        reshaped so that its leading dimension is the original second
        dimension.
        """
        scratch = torch.zeros_like(param.data)
        ori_shape = scratch.shape
        ops.trans_w16_gemm(scratch, param.data, ori_shape[0], ori_shape[1])
        param.data.copy_(scratch)
        param.data = param.data.reshape(ori_shape[1], -1)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module.

        - The ``"medusa_heads."`` prefix is stripped from checkpoint names.
        - A ``token_map`` entry is kept only when truncated_vocab_size is
          smaller than vocab_size.
        - With ``original_lm_head``, ``lm_heads.0.weight`` is loaded into the
          shared ``lm_head``.
        - With a token map, LM-head weights larger than the map are sliced to
          the mapped rows before loading.

        Returns:
            Names of the parameters that were loaded.

        Raises:
            AssertionError: The vocab is truncated but no token_map was found.
        """
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        # Use .get(): indexing os.environ raised KeyError whenever LLAMA_NN=1
        # but LM_NN was unset. The value does not change inside the loop.
        relayout_lm_head = self.use_llama_nn and os.environ.get("LM_NN") == "1"

        weights_map = {}

        for name, loaded_weight in weights:
            name = name.replace("medusa_heads.", "")

            if name == "token_map":
                if self.truncated_vocab_size < self.orig_vocab_size:
                    self.token_map = nn.Parameter(loaded_weight, requires_grad=False)
            elif name in params_dict:
                weights_map[name] = loaded_weight
            elif (
                getattr(self.config, "original_lm_head", False)
                and name == "lm_heads.0.weight"
            ):
                weights_map["lm_head.weight"] = loaded_weight

        for name, loaded_weight in weights_map.items():
            # "lm_head" also matches "lm_heads.<i>.weight".
            if (
                "lm_head" in name
                and self.token_map is not None
                and loaded_weight.shape[0] > self.token_map.shape[0]
            ):
                loaded_weight = loaded_weight[self.token_map]

            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)

            if relayout_lm_head and "lm_head" in name:
                self._relayout_lm_head_weight(param)

        if self.token_map is not None:
            # Tensor.to() returns a new tensor rather than moving in place; the
            # previous bare call was a no-op. Assign via .data because
            # token_map is a registered Parameter, which rejects plain tensors.
            self.token_map.data = self.token_map.data.to(
                device=self.lm_heads[0].weight.device
            )

        assert (self.truncated_vocab_size == self.orig_vocab_size) or (
            self.token_map is not None
        ), "token_map is required when truncated_vocab_size < vocab_size"

        return loaded_params