# medusa.py
1
from typing import Iterable, List, Optional, Set, Tuple
2
3
4
5

import torch
import torch.nn as nn

6
from vllm.config import VllmConfig
7
from vllm.model_executor.layers.logits_processor import LogitsProcessor
8
from vllm.model_executor.layers.sampler import SamplerOutput
9
10
11
12
13
14
15
16
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata


class ResidualBlock(nn.Module):
    """Stack of residual, SiLU-activated linear layers used by one Medusa head.

    Every layer is a square ``nn.Linear(hidden_size, hidden_size)`` applied as
    ``x <- x + SiLU(linear(x))``, so the output has the same shape as the
    input.

    Args:
        config: Configuration object. Only the optional ``medusa_fc_bias``
            attribute is read (treated as ``False`` when absent); it controls
            whether the linear layers carry a bias. NOTE(review): ``Medusa``
            passes ``vllm_config.model_config.hf_config`` here, not a
            ``VllmConfig`` as the annotation suggests.
        hidden_size: Input/output width of every linear layer.
        num_layers: Number of residual linear layers to stack.
    """

    def __init__(self, config: VllmConfig, hidden_size: int,
                 num_layers: int) -> None:
        super().__init__()

        use_bias = getattr(config, "medusa_fc_bias", False)
        linears = [
            nn.Linear(hidden_size, hidden_size, bias=use_bias)
            for _ in range(num_layers)
        ]
        self.layers = nn.ModuleList(linears)
        self.act = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply each residual layer in order and return the final tensor."""
        out = x
        for linear in self.layers:
            # Out-of-place add: the caller's input tensor is never mutated.
            out = out + self.act(linear(out))
        return out


class Medusa(nn.Module):
    """Medusa draft model (paper: https://arxiv.org/abs/2401.10774).

    Reference implementation: https://github.com/FasterDecoding/Medusa

    Differences from the reference implementation:

    1. Proposals are generated from the top-1 token of each head only.
    2. An optional ``token_map`` reduces the draft vocabulary to the most
       frequently used tokens, giving some additional speed-up by reducing
       sampling overhead. It is used only when the checkpoint contains an
       explicit ``token_map`` tensor AND the config's optional
       ``truncated_vocab_size`` attribute (defaults to ``vocab_size``) is
       smaller than ``vocab_size``. To use this technique, find the top-k
       most frequent tokens in the target dataset, store their ids as a
       tensor in the draft checkpoint under the key ``token_map``, and set
       ``truncated_vocab_size = k`` in the draft config.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build the Medusa heads from ``vllm_config.model_config.hf_config``.

        The HF config must provide ``hidden_size``, ``num_hidden_layers``,
        ``num_heads`` and ``vocab_size``; ``truncated_vocab_size`` (default:
        ``vocab_size``) and ``logit_scale`` (default: 1.0) are optional.
        ``prefix`` is accepted for interface compatibility and is unused.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.config = config
        num_heads = config.num_heads

        # One residual MLP trunk per Medusa head.
        self.blocks = nn.ModuleList([
            ResidualBlock(config=config,
                          hidden_size=config.hidden_size,
                          num_layers=config.num_hidden_layers)
            for _ in range(num_heads)
        ])

        self.orig_vocab_size = config.vocab_size
        # Optional per the class docstring: no truncation when absent.
        self.truncated_vocab_size = getattr(config, "truncated_vocab_size",
                                            self.orig_vocab_size)
        # The LM heads only produce logits over the (possibly truncated)
        # draft vocabulary.
        self.unpadded_vocab_size = self.truncated_vocab_size

        self.lm_heads = nn.ModuleList([
            ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=self.truncated_vocab_size,
                padding_size=DEFAULT_VOCAB_PADDING_SIZE,
            ) for _ in range(num_heads)
        ])

        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                self.truncated_vocab_size,
                                                logit_scale)

        # Optional index -> token-id map that shrinks the draft vocabulary to
        # the most frequent tokens to reduce speculation overhead. It stays
        # None unless load_weights finds a ``token_map`` tensor in the Medusa
        # checkpoint and truncated_vocab_size < vocab_size.
        self.token_map: Optional[nn.Parameter] = None

    def forward(self, hidden_states: torch.Tensor) -> List[torch.Tensor]:
        """Run ``hidden_states`` through every head's residual block.

        Returns one tensor per Medusa head, in head order.
        """
        return [block(hidden_states) for block in self.blocks]

    def compute_logits(
            self, hidden_states: List[torch.Tensor],
            sampling_metadata: SamplingMetadata) -> List[torch.Tensor]:
        """Project each head's hidden states to vocabulary logits.

        Args:
            hidden_states: One tensor per head, as returned by ``forward``.
            sampling_metadata: Passed through to the logits processor.

        Returns:
            One logits tensor per head. When ``token_map`` is set, the
            truncated-vocab logits are scattered into a tensor whose last
            dimension is ``orig_vocab_size`` and whose other entries are
            ``-inf``. An empty list is returned when the logits processor
            yields ``None`` (expected only on ranks > 0).
        """
        logits_lst: List[torch.Tensor] = []

        for hs, lm_head in zip(hidden_states, self.lm_heads):
            _logits = self.logits_processor(lm_head, hs, sampling_metadata)

            if _logits is None:
                # _logits should only be None on rank > 0, in which case
                # it should remain true for every lm_head
                assert len(logits_lst) == 0
                continue

            if self.token_map is None:
                logits_lst.append(_logits)
                continue

            # Expand the truncated-vocab logits back to the full vocabulary;
            # tokens outside the map get -inf (zero probability after softmax).
            full_logits = torch.full(
                (*_logits.shape[:-1], self.orig_vocab_size),
                -torch.inf,
                device=_logits.device,
                dtype=_logits.dtype,
            )
            full_logits[..., self.token_map] = _logits
            logits_lst.append(full_logits)

        return logits_lst

    def sample(
        self,
        logits: List[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> List[SamplerOutput]:
        """Greedily pick one token per head for every sequence group.

        Args:
            logits: One logits tensor per head (see ``compute_logits``).
            sampling_metadata: Its ``seq_groups[i].sample_indices`` select the
                rows that belong to sequence group ``i``.

        Returns:
            One ``SamplerOutput`` per sequence group (``outputs=None``) holding
            the argmax token ids plus the probability and log-probability
            tensors for that group's rows.
        """
        # Stack the heads along a new leading dim and upcast to fp32 before
        # the softmax computations.
        stacked = torch.stack(logits, dim=0).float()
        logprobs = torch.log_softmax(stacked, dim=-1)
        token_ids = stacked.argmax(-1)  # support only top-1 for now
        probs = torch.softmax(stacked, dim=-1)

        outputs: List[SamplerOutput] = []
        for seq_group in sampling_metadata.seq_groups:
            rows = seq_group.sample_indices
            outputs.append(
                SamplerOutput(
                    outputs=None,
                    sampled_token_probs=probs[:, rows].squeeze(1),
                    logprobs=logprobs[:, rows].squeeze(1),
                    sampled_token_ids=token_ids[:, rows].squeeze(1),
                ))

        return outputs

    def generate_proposals(
        self,
        previous_hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> List[SamplerOutput]:
        """Produce greedy per-head proposals from ``previous_hidden_states``.

        Chains ``forward`` -> ``compute_logits`` -> ``sample``.
        """
        head_hidden_states = self.forward(previous_hidden_states)
        head_logits = self.compute_logits(
            hidden_states=head_hidden_states,
            sampling_metadata=sampling_metadata,
        )
        return self.sample(logits=head_logits,
                           sampling_metadata=sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors into this module.

        Checkpoint names have ``medusa_heads.`` stripped before being matched
        against ``named_parameters()``; unmatched names are ignored. A
        ``token_map`` tensor is kept only when
        ``truncated_vocab_size < orig_vocab_size``. In that case it is used to
        slice full-vocabulary ``lm_head`` weights down to the draft
        vocabulary, and it is then moved to the LM heads' device.

        Returns:
            The (prefix-stripped) names of the parameters that were loaded.

        Raises:
            AssertionError: if the vocabulary is truncated but no
                ``token_map`` was found in ``weights``.
        """
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()

        # First pass: capture token_map and buffer every other weight. The
        # lm_head weights cannot be sliced until token_map is known, and the
        # token_map entry may appear anywhere in the iterator.
        weights_map = {}
        for name, loaded_weight in weights:
            name = name.replace("medusa_heads.", "")

            if name == "token_map":
                if self.truncated_vocab_size < self.orig_vocab_size:
                    self.token_map = nn.Parameter(loaded_weight,
                                                  requires_grad=False)
            elif name in params_dict:
                weights_map[name] = loaded_weight

        # Second pass: load the buffered weights.
        for name, loaded_weight in weights_map.items():
            if ("lm_head" in name and self.token_map is not None
                    and loaded_weight.shape[0] > self.token_map.shape[0]):
                # Keep only the rows of the truncated draft vocabulary.
                loaded_weight = loaded_weight[self.token_map]

            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)

        if self.token_map is not None:
            # Tensor.to() is not in-place: the previous code discarded its
            # result, leaving token_map on its load device. Rebind so the map
            # lives on the LM heads' device, which compute_logits indexes on.
            self.token_map = nn.Parameter(
                self.token_map.to(device=self.lm_heads[0].weight.device),
                requires_grad=False)

        assert (self.truncated_vocab_size == self.orig_vocab_size
                or self.token_map is not None), (
                    "truncated_vocab_size differs from vocab_size but the "
                    "checkpoint has no token_map tensor")

        return loaded_params