medusa.py 8.37 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
from collections.abc import Iterable
from typing import Optional
5
6
7
8

import torch
import torch.nn as nn

9
from vllm.config import VllmConfig
10
from vllm.model_executor.layers.logits_processor import LogitsProcessor
11
from vllm.model_executor.layers.sampler import SamplerOutput
12
13
14
15
16
17
18
19
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata


class ResidualBlock(nn.Module):
    """Stack of residual ``x + SiLU(Linear(x))`` layers used by one Medusa head.

    Every layer is a square ``nn.Linear(hidden_size, hidden_size)``, so the
    last dimension of the input is preserved end to end.
    """

    def __init__(self, config: object, hidden_size: int,
                 num_layers: int) -> None:
        """Build ``num_layers`` residual linear layers.

        Args:
            config: The draft model's HF config object (``Medusa`` passes
                ``vllm_config.model_config.hf_config`` here, not a
                ``VllmConfig``). Only the optional ``medusa_fc_bias``
                attribute is read; when it is absent the linear layers are
                created without a bias.
            hidden_size: Input and output width of every layer.
            num_layers: Number of residual layers to stack.
        """
        super().__init__()

        use_bias = getattr(config, "medusa_fc_bias", False)
        self.layers = nn.ModuleList([
            nn.Linear(hidden_size, hidden_size, bias=use_bias)
            for _ in range(num_layers)
        ])
        self.act = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply each layer residually: ``x <- x + SiLU(layer(x))``."""
        for layer in self.layers:
            x = x + self.act(layer(x))
        return x


class Medusa(nn.Module):
    """This class implements the Medusa draft model from the paper: https://arxiv.org/abs/2401.10774
    Reference implementation: https://github.com/FasterDecoding/Medusa

    Differences from reference implementation:
    1. Currently this only supports generating proposals from top-1 tokens.
    2. We have an optional token_map which reduces draft vocab to most
       frequently used tokens to give some additional speed-up by reducing
       sampling overhead. This is disabled unless the checkpoint file has
       explicit token_map tensor and config has an optional attribute
       truncated_vocab_size < vocab_size. To use this technique, one has to find
       the top-k most frequent tokens in target dataset and add that as a tensor
       in the draft checkpoint (using key token_map). Also, the draft config
       needs to have truncated_vocab_size (=k) as an attribute."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build one residual block and one LM head per Medusa head.

        Args:
            vllm_config: Engine config. Only ``model_config.hf_config`` is
                read. It must define ``hidden_size``, ``num_hidden_layers``,
                ``num_heads`` and ``vocab_size``. These are optional:
                ``truncated_vocab_size`` (defaults to ``vocab_size``),
                ``original_lm_head`` (defaults to False), ``logit_scale``
                (defaults to 1.0) and ``medusa_fc_bias``.
            prefix: Accepted for the common model-constructor interface; not
                used by this model.
        """
        config = vllm_config.model_config.hf_config
        super().__init__()
        self.config = config
        self.blocks = nn.ModuleList([
            ResidualBlock(config=config,
                          hidden_size=self.config.hidden_size,
                          num_layers=self.config.num_hidden_layers)
            for _ in range(self.config.num_heads)
        ])
        self.orig_vocab_size = config.vocab_size
        # truncated_vocab_size is documented as optional, so fall back to the
        # full vocabulary (i.e. no truncation) when the config omits it.
        self.truncated_vocab_size = getattr(config, "truncated_vocab_size",
                                            config.vocab_size)
        self.unpadded_vocab_size = self.truncated_vocab_size

        if getattr(config, "original_lm_head", False):
            # All Medusa heads share one LM head. ``lm_heads`` is a plain
            # list (not an nn.ModuleList), so only ``self.lm_head`` is
            # registered as a submodule and its weight is loaded once.
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=self.truncated_vocab_size,
                padding_size=DEFAULT_VOCAB_PADDING_SIZE,
            )
            self.lm_heads = [
                self.lm_head for _ in range(self.config.num_heads)
            ]
        else:
            self.lm_heads = nn.ModuleList([
                ParallelLMHead(
                    self.unpadded_vocab_size,
                    config.hidden_size,
                    org_num_embeddings=self.truncated_vocab_size,
                    padding_size=DEFAULT_VOCAB_PADDING_SIZE,
                ) for _ in range(self.config.num_heads)
            ])

        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                self.truncated_vocab_size,
                                                logit_scale)

        # Token map is an idx -> token mapping that reduces the vocab size
        # of the draft model. Using a smaller draft vocab, containing only
        # the most frequent tokens, reduces the speculation overhead. This
        # doesn't affect the acceptance rate much and thus gives more
        # speed-up. By default this is disabled; it is only set by
        # load_weights when the Medusa checkpoint file has a token_map
        # tensor.
        self.token_map: Optional[nn.Parameter] = None

    def forward(self, hidden_states: torch.Tensor) -> list[torch.Tensor]:
        """Run every Medusa residual block on the same input hidden states.

        Returns one tensor per head, in head order.
        """
        return [block(hidden_states) for block in self.blocks]

    def compute_logits(
            self, hidden_states: list[torch.Tensor],
            sampling_metadata: SamplingMetadata) -> list[torch.Tensor]:
        """Project each head's hidden states to logits with its LM head.

        Args:
            hidden_states: One tensor per head, as returned by ``forward``.
            sampling_metadata: Passed through to the logits processor.

        Returns:
            One logits tensor per head. When a token map is loaded, each
            tensor is expanded back to the full ``vocab_size``, and tokens
            outside the map get -inf. The list is empty when the logits
            processor returns None for every head.
        """
        logits_lst: list[torch.Tensor] = []

        for hs, lm_head in zip(hidden_states, self.lm_heads):
            _logits = self.logits_processor(lm_head, hs, sampling_metadata)

            if _logits is None:
                # _logits should only be None on rank > 0, in which case
                # it should remain true for every lm_head
                assert len(logits_lst) == 0
                continue

            if self.token_map is None:
                logits_lst.append(_logits)
                continue

            # Scatter the truncated-vocab logits into a full-vocab tensor so
            # that tokens outside the token map can never be selected.
            full_logits = torch.full(
                (*_logits.shape[:-1], self.orig_vocab_size),
                -torch.inf,
                device=_logits.device,
                dtype=_logits.dtype,
            )
            full_logits[..., self.token_map] = _logits
            logits_lst.append(full_logits)

        return logits_lst

    def sample(
        self,
        logits: list[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> list[SamplerOutput]:
        """Greedily pick the top-1 token of every head for each sequence group.

        Args:
            logits: One logits tensor per head (from ``compute_logits``).
            sampling_metadata: Supplies ``seq_groups``. Each group's
                ``sample_indices`` select its rows from the logits.

        Returns:
            One ``SamplerOutput`` per sequence group. It holds the per-head
            token ids, probabilities and log-probabilities; ``outputs`` is
            None.
        """
        # Stack along a new leading dim, so dim 0 indexes the Medusa head.
        logits = torch.stack(logits, dim=0).float()
        logprobs = torch.log_softmax(logits, dim=-1)
        token_ids = logits.argmax(-1)  # support only top-1 for now
        probs = torch.softmax(logits, dim=-1)

        outputs: list[SamplerOutput] = []
        for seq_group in sampling_metadata.seq_groups:
            sample_indices = seq_group.sample_indices
            # squeeze(1) drops the sample dim only when it has size 1.
            outputs.append(
                SamplerOutput(
                    outputs=None,
                    sampled_token_probs=probs[:, sample_indices].squeeze(1),
                    logprobs=logprobs[:, sample_indices].squeeze(1),
                    sampled_token_ids=token_ids[:, sample_indices].squeeze(1),
                ))

        return outputs

    def generate_proposals(
        self,
        previous_hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> list[SamplerOutput]:
        """Run forward -> compute_logits -> sample on the target's hidden states."""
        return self.sample(
            logits=self.compute_logits(
                hidden_states=self.forward(previous_hidden_states),
                sampling_metadata=sampling_metadata,
            ),
            sampling_metadata=sampling_metadata,
        )

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load Medusa weights from a checkpoint.

        Strips the ``medusa_heads.`` prefix from weight names and picks up an
        optional ``token_map`` tensor. When a token map is present, LM head
        weights with more rows than the map are sliced down to the mapped
        rows.

        Returns:
            Names of the parameters that were loaded. ``token_map`` is not
            included.
        """
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        # Buffer matching weights first, so token_map is known before any
        # LM head weight is sliced, regardless of checkpoint order.
        weights_map: dict[str, torch.Tensor] = {}

        for name, loaded_weight in weights:
            name = name.replace("medusa_heads.", "")

            if name == "token_map":
                if self.truncated_vocab_size < self.orig_vocab_size:
                    self.token_map = nn.Parameter(loaded_weight,
                                                  requires_grad=False)
            elif name in params_dict:
                weights_map[name] = loaded_weight
            elif (getattr(self.config, "original_lm_head", False)
                  and name == "lm_heads.0.weight"):
                # The shared head is registered as ``lm_head``; load it from
                # the first per-head weight in the checkpoint.
                weights_map["lm_head.weight"] = loaded_weight

        for name, loaded_weight in weights_map.items():
            if ("lm_head" in name and self.token_map is not None
                    and loaded_weight.shape[0] > self.token_map.shape[0]):
                loaded_weight = loaded_weight[self.token_map]

            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)

        if self.token_map is not None:
            # Tensor.to() returns a new tensor rather than moving in place,
            # so the result must be rebound. This puts the map on the same
            # device as the LM head weights that compute_logits pairs it with.
            self.token_map = nn.Parameter(
                self.token_map.to(device=self.lm_heads[0].weight.device),
                requires_grad=False)

        assert (self.truncated_vocab_size == self.orig_vocab_size) or (
            self.token_map is not None), (
                "truncated_vocab_size < vocab_size requires a 'token_map' "
                "tensor in the Medusa checkpoint")

        return loaded_params