# medusa.py
1
2
import os
from typing import Iterable, List, Optional, Tuple, Any, Dict
3
4
5
6
7

import torch
import torch.nn as nn

from vllm.model_executor.layers.logits_processor import LogitsProcessor
8
from vllm.model_executor.layers.sampler import SamplerOutput
9
10
11
12
13
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.transformers_utils.configs.medusa import MedusaConfig
14
from vllm import _custom_ops as ops
15
16
from vllm.distributed import (tensor_model_parallel_all_gather,
                              tensor_model_parallel_gather)
17
18

# Number of top-scoring tokens kept per Medusa head when building the sparse
# candidate tree (10 is a placeholder and it is sufficient).
TOPK: int = 10
19
20
21
22
23
24
25
26


class ResidualBlock(nn.Module):
    """Stack of residual layers that keeps the hidden size unchanged.

    Every layer applies ``x + SiLU(Linear(x))`` where the linear map is
    ``hidden_size -> hidden_size``. With ``num_layers == 0`` the block is the
    identity.

    Args:
        hidden_size: Size of the last dimension of the input and output.
        num_layers: Number of residual linear layers applied in sequence.
    """

    def __init__(self, hidden_size: int, num_layers: int) -> None:
        super().__init__()

        self.layers = nn.ModuleList([
            nn.Linear(hidden_size, hidden_size)
            for _ in range(num_layers)
        ])
        self.act = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply every residual layer in order and return the result."""
        for layer in self.layers:
            # Residual connection: the activation output is added to the
            # layer input rather than replacing it.
            x = x + self.act(layer(x))
        return x


class Medusa(nn.Module):
    """This class implements the Medusa draft model from the paper: https://arxiv.org/abs/2401.10774
    Reference implementation: https://github.com/FasterDecoding/Medusa

    Differences from reference implementation:
    1. The default path (`sample`) only generates proposals from top-1
       tokens. When `medusa_buffers` is passed to `generate_proposals`,
       `medusa_sample` is used instead and builds tree candidates from the
       top-TOPK tokens of every head.
    2. We have an optional token_map which reduces draft vocab to most
       frequently used tokens to give some additional speed-up by reducing
       sampling overhead. This is disabled unless the checkpoint file has
       explicit token_map tensor and config has an optional attribute
       truncated_vocab_size < vocab_size. To use this technique, one has to find
       the top-k most frequent tokens in target dataset and add that as a tensor
       in the draft checkpoint (using key token_map). Also, the draft config
       needs to have truncated_vocab_size (=k) as an attribute."""

    def __init__(self, config: "MedusaConfig", **_) -> None:
        """Build one residual block and one LM head per Medusa head.

        Args:
            config: Draft-model configuration. The attributes read here are
                hidden_size, num_hidden_layers, num_heads, vocab_size,
                truncated_vocab_size, medusa_choices and (optionally)
                logit_scale.
            **_: Extra keyword arguments are accepted and ignored.
        """
        super().__init__()

        # Enables the optional lm_head weight re-layout in load_weights
        # (that step additionally requires LM_NN=1).
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
        self.config = config
        self.blocks = nn.ModuleList([
            ResidualBlock(hidden_size=self.config.hidden_size,
                          num_layers=self.config.num_hidden_layers)
            for _ in range(self.config.num_heads)
        ])
        self.orig_vocab_size = config.vocab_size
        self.truncated_vocab_size = config.truncated_vocab_size
        self.unpadded_vocab_size = self.truncated_vocab_size
        self.medusa_choices = config.medusa_choices

        self.lm_heads = nn.ModuleList([
            ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=self.truncated_vocab_size,
                padding_size=DEFAULT_VOCAB_PADDING_SIZE,
            ) for _ in range(self.config.num_heads)
        ])

        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                self.truncated_vocab_size,
                                                logit_scale)

        # Token map is a idx to token mapping to reduce the vocab size for
        # the draft model. Using smaller vocab size for draft, containing
        # only most frequent tokens reduces the speculation overhead. This
        # doesn't affect the acceptance rate much and thus gives more speed
        # -up. By default, this is disabled and is only used if the draft
        # checkpoint file has token_map tensor.
        self.token_map = None

    def forward(self, hidden_states: torch.Tensor) -> List[torch.Tensor]:
        """Run every residual block on the same hidden states.

        Returns:
            One tensor per Medusa head, in the order of ``self.blocks``.
        """
        return [block(hidden_states) for block in self.blocks]

    def compute_logits(
            self, hidden_states: List[torch.Tensor]) -> List[torch.Tensor]:
        """Project per-head hidden states to vocabulary logits.

        Args:
            hidden_states: One tensor per head, as returned by ``forward``.

        Returns:
            One logits tensor per head. Without a token map the last
            dimension is the draft vocabulary; with a token map the logits
            are scattered into a ``-inf``-filled tensor whose last dimension
            is ``orig_vocab_size``.
        """
        logits_lst: List[torch.Tensor] = []

        for hs, lm_head in zip(hidden_states, self.lm_heads):
            # The head is applied directly (not through
            # self.logits_processor) and the vocab shards are gathered on
            # every rank.
            _logits = lm_head.linear_method.apply(lm_head, hs, bias=None)
            _logits = tensor_model_parallel_all_gather(_logits)

            if _logits is None:
                # _logits should only be None on rank > 0, in which case
                # it should remain true for every lm_head
                assert len(logits_lst) == 0
                continue

            # Bypassing the logits processor also bypasses its removal of
            # vocab padding, so drop any padded columns here. This is a
            # no-op when the head was not padded.
            # NOTE(review): assumes real tokens occupy the leading columns
            # after the all-gather — confirm for the tensor-parallel layout.
            _logits = _logits[..., :self.truncated_vocab_size]

            if self.token_map is None:
                logits_lst.append(_logits)
            else:
                # Tokens outside the truncated draft vocab get -inf so they
                # can never be selected.
                full_logits = torch.full(
                    size=(*_logits.shape[:-1], self.orig_vocab_size),
                    fill_value=float("-inf"),
                    device=_logits.device,
                    dtype=_logits.dtype)
                full_logits[..., self.token_map] = _logits
                logits_lst.append(full_logits)

        return logits_lst

    def sample(
        self,
        logits: List[torch.Tensor],
        sample_indices_list: List[List[int]],
    ) -> List["SamplerOutput"]:
        """Greedy (top-1) proposal sampling.

        Args:
            logits: One logits tensor per head, as returned by
                ``compute_logits``.
            sample_indices_list: For every output, the indices selected
                along dimension 1 of the stacked per-head tensors.

        Returns:
            One ``SamplerOutput`` per entry of ``sample_indices_list``.
        """
        logits = torch.stack(logits, dim=0).float()
        logprobs = torch.log_softmax(logits, dim=-1)
        token_ids = logits.argmax(-1)  # support only top-1 for now
        probs = torch.softmax(logits, dim=-1)

        outputs: List[Optional[SamplerOutput]] = []
        for sample_indices in sample_indices_list:
            outputs.append(
                SamplerOutput(
                    outputs=None,
                    sampled_token_probs=probs[:, sample_indices].squeeze(1),
                    logprobs=logprobs[:, sample_indices].squeeze(1),
                    sampled_token_ids=token_ids[:, sample_indices].squeeze(1),
                ))

        return outputs

    def medusa_sample(
        self,
        medusa_logits: List[torch.Tensor],
        sample_indices_list: List[List[int]],
        logits: torch.Tensor,
        medusa_buffers: Dict[str, Any]
    ) -> List["SamplerOutput"]:
        """Build sparse-tree candidates from the top-TOPK tokens per head.

        Args:
            medusa_logits: One logits tensor per head, as returned by
                ``compute_logits``.
            sample_indices_list: For every output, the indices selected
                along dimension 0 of the candidate tensors.
            logits: Logits whose argmax supplies the first candidate token.
            medusa_buffers: Must provide the index tensors 'tree_indices'
                and 'retrieve_indices'.

        Returns:
            One ``SamplerOutput`` per entry of ``sample_indices_list``,
            carrying the tree candidates as ``sampled_token_ids`` and the
            retrieved candidates as ``cart_candidates``.
        """
        batch_size = logits.shape[0]
        # Greedy token taken from `logits`: [batch_size, 1]
        base_candidates = torch.argmax(logits, dim=-1).view(batch_size, -1)

        # [medusa_heads, batch_size, vocab_size]
        stacked_logits = torch.stack(medusa_logits, dim=0)

        # Extract the TOPK candidates from the medusa logits:
        # [medusa_heads, batch_size, TOPK] -> [batch_size, medusa_heads*TOPK]
        head_candidates = torch.topk(stacked_logits, TOPK, dim=-1).indices
        head_candidates = head_candidates.permute(1, 0, 2)
        head_candidates = head_candidates.reshape(batch_size, -1)

        # Combine the selected candidate from the original logits with the
        # topk medusa candidates: [batch_size, 1 + medusa_heads*TOPK]
        candidates = torch.cat([base_candidates, head_candidates], dim=-1)

        # Map the combined candidates to the tree indices to get tree
        # candidates: [batch_size, choices]
        tree_candidates = torch.index_select(
            candidates, dim=-1, index=medusa_buffers['tree_indices'])

        # Extend the tree candidates by appending a zero:
        # [batch_size, choices + 1]
        # NOTE(review): presumably the extra slot is what padding entries in
        # retrieve_indices point at — confirm against the buffer builder.
        zero_column = torch.zeros((batch_size, 1),
                                  dtype=torch.long,
                                  device=tree_candidates.device)
        tree_candidates_ext = torch.cat([tree_candidates, zero_column],
                                        dim=-1)

        # Retrieve the cartesian candidates using the retrieve indices:
        # [batch_size, retrieve_size, max_depth]
        cart_candidates = tree_candidates_ext[
            :, medusa_buffers['retrieve_indices']]

        outputs: List[Optional[SamplerOutput]] = []
        for sample_indices in sample_indices_list:
            outputs.append(
                SamplerOutput(
                    outputs=None,
                    sampled_token_ids=tree_candidates[
                        sample_indices, :].squeeze(1),
                    cart_candidates=cart_candidates[sample_indices, :]
                ))

        return outputs

    def generate_proposals(
        self,
        previous_hidden_states: torch.Tensor,
        sample_indices_list: List[List[int]],
        previous_logits: Optional[torch.Tensor] = None,
        medusa_buffers: Optional[Dict[str, Any]] = None
    ) -> List["SamplerOutput"]:
        """Generate draft proposals from the previous hidden states.

        Args:
            previous_hidden_states: Hidden states fed to every Medusa head.
            sample_indices_list: Forwarded unchanged to the sampler.
            previous_logits: Logits forwarded to ``medusa_sample``; required
                when ``medusa_buffers`` is given.
            medusa_buffers: When None the top-1 ``sample`` path is used,
                otherwise the tree-based ``medusa_sample`` path.

        Raises:
            ValueError: If ``medusa_buffers`` is given without
                ``previous_logits``.
        """
        if medusa_buffers is not None and previous_logits is None:
            # Fail early with a clear message instead of an AttributeError
            # deep inside medusa_sample.
            raise ValueError(
                "previous_logits is required when medusa_buffers is given")

        medusa_logits = self.compute_logits(
            hidden_states=self.forward(previous_hidden_states))

        if medusa_buffers is None:
            return self.sample(
                logits=medusa_logits,
                sample_indices_list=sample_indices_list,
            )

        return self.medusa_sample(
            medusa_logits=medusa_logits,
            sample_indices_list=sample_indices_list,
            logits=previous_logits,
            medusa_buffers=medusa_buffers)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors into this module.

        The "medusa_heads." prefix is stripped from every name. A tensor
        named "token_map" becomes ``self.token_map`` when the draft vocab is
        truncated; other tensors are loaded only if they match a parameter
        name of this module.

        Args:
            weights: Iterable of (name, tensor) pairs from the checkpoint.
        """
        params_dict = dict(self.named_parameters())

        weights_map = {}

        # First pass: pick up the token map and collect known parameters, so
        # the map is available for every lm_head regardless of file order.
        for name, loaded_weight in weights:
            name = name.replace("medusa_heads.", "")

            if name == "token_map":
                if self.truncated_vocab_size < self.orig_vocab_size:
                    self.token_map = nn.Parameter(loaded_weight,
                                                  requires_grad=False)
            elif name in params_dict:
                weights_map[name] = loaded_weight

        # .get avoids a KeyError when LLAMA_NN=1 is set without LM_NN.
        repack_lm_head = (self.use_llama_nn
                          and os.environ.get('LM_NN') == '1')

        for name, loaded_weight in weights_map.items():
            is_lm_head = "lm_head" in name

            # A full-vocab lm_head checkpoint is reduced to the rows listed
            # in the token map.
            if (is_lm_head and self.token_map is not None
                    and loaded_weight.shape[0] > self.token_map.shape[0]):
                loaded_weight = loaded_weight[self.token_map]

            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

            if repack_lm_head and is_lm_head:
                # Re-layout the lm_head weight with the custom op and store
                # it back with the two dimensions swapped.
                _weight = torch.zeros_like(param.data)
                ori_shape = _weight.shape

                ops.trans_w16_gemm(_weight, param.data, _weight.shape[0],
                                   _weight.shape[1])
                param.data.copy_(_weight)

                param.data = param.data.reshape(ori_shape[1], -1)

        if self.token_map is not None:
            # Tensor.to is not in-place: its result must be stored,
            # otherwise the map stays on its original device.
            self.token_map.data = self.token_map.data.to(
                device=self.lm_heads[0].weight.device)

        assert (self.truncated_vocab_size
                == self.orig_vocab_size) or (self.token_map is not None), (
                    "truncated_vocab_size differs from vocab_size but the "
                    "checkpoint provided no token_map")