# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

4
5
from collections.abc import Iterable
from typing import Optional
6
7
8
9

import torch
import torch.nn as nn

10
from vllm.config import VllmConfig
11
from vllm.logger import init_logger
12
from vllm.model_executor.layers.layernorm import RMSNorm
13
14
15
16
17
18
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.sampling_metadata import SamplingMetadata
19
from vllm.sequence import IntermediateTensors
20

21
22
from .utils import maybe_prefix

23
24
logger = init_logger(__name__)

25

26
27
class DummyInputLayerNorm(nn.Module):
    """Identity stand-in for an input layer norm.

    The module performs no normalization: ``forward`` returns its input
    unchanged. An optional ``weight`` and ``bias`` can still be attached as
    parameters so that code which inspects those attributes keeps working.
    """

    def __init__(self, weight=None, bias=None):
        """Store ``weight`` / ``bias`` as parameters when they are provided.

        Args:
            weight: Optional tensor wrapped in ``nn.Parameter``; the
                attribute is ``None`` when omitted.
            bias: Optional tensor wrapped in ``nn.Parameter``; the attribute
                is ``None`` when omitted.
        """
        super().__init__()
        self.weight = None if weight is None else nn.Parameter(weight)
        self.bias = None if bias is None else nn.Parameter(bias)

    def forward(self, x):
        """Return ``x`` untouched (no normalization is applied)."""
        return x


class DummyOutputNorm(nn.Module):
    """Stand-in for an output norm that skips normalization entirely."""

    def forward(self, x, residual):
        """Fold the residual into ``x`` without normalizing.

        Args:
            x: Input tensor.
            residual: Tensor to add to ``x``, or ``None``.

        Returns:
            ``x`` itself when ``residual`` is ``None``; otherwise the tuple
            ``(x + residual, None)``.
        """
        if residual is not None:
            return x + residual, None
        return x
44
45


46
47
48
49
50
51
class EAGLE(nn.Module):
    """This class implements the EAGLE draft model from the paper: https://arxiv.org/pdf/2401.15077
    Reference implementation: https://github.com/SafeAILab/EAGLE

    Differences from reference implementation:
    1. In reference, LlamaDecoderLayer implementation doesn't have
       input_layernorm for 1st decoder layer (https://github.com/SafeAILab/EAGLE/blob/7d065d084443fbfd386f88839efd7193c12be869/eagle/model/cnets.py#L427).
       Following this approach, our implementation also disables
       the input_layernorm for the first decoder layer.
    2. We allow any decoder layer to be used in EAGLE whereas in reference
       decoder layer is fixed to be LlamaDecoderLayer.
    3. We have an optional token_map which reduces draft vocab to most
       frequently used tokens to give some additional speed-up by reducing
       sampling overhead. This is disabled unless the checkpoint file has
       explicit token_map tensor and config has an optional attribute
       truncated_vocab_size < vocab_size. To use this technique, one has to find
       the top-k most frequent tokens in target dataset and add that as a tensor
       in the draft checkpoint (using key token_map). Also, the draft config
       needs to have truncated_vocab_size (=k) as an attribute.
    4. We allow an enhanced EAGLE architecture similar to the DeepSeek MTP
       module with regards to the use of additional RMS norms. The original
       EAGLE architecture 1) skips the pre-attention norm in its first
       transformer block, and 2) skips the final output norm, both of which we
       found to be suboptimal. We also add the support for separate norms
       applying to both the token embedding and hidden states before projection
       as in DeepSeek MTP, which we found to improve performance as well.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the wrapped decoder model, the fusion layer and the LM head.

        Args:
            vllm_config: vLLM configuration; its
                ``model_config.hf_config`` is used as the EAGLE config and
                ``hf_config.model`` as the wrapped model's config.
            prefix: Module-name prefix forwarded (via ``maybe_prefix``) to
                the wrapped model.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.config = config
        inner_config = config.model

        architectures = getattr(inner_config, "architectures", [])
        model_cls, _ = ModelRegistry.resolve_model_cls(architectures)

        self.model = model_cls(vllm_config=vllm_config,
                               prefix=maybe_prefix(prefix, "model"))

        # Projects the concatenation [token embedding, previous hidden state]
        # (2 * hidden_size) back down to hidden_size.
        self.fc = nn.Linear(inner_config.hidden_size * 2,
                            inner_config.hidden_size,
                            bias=getattr(config, "eagle_fc_bias", False))

        # Modify layer normalization and residual connections as suggested
        # in the EAGLE framework: https://github.com/SafeAILab/EAGLE
        # While weights and biases are generally not needed,
        # they are retained here to support certain unit tests
        # (e.g., spec_decode/e2e/test_eagle_correctness.py).
        # Both skips default to enabled when the config does not mention them.
        if getattr(inner_config, "skip_prenorm", True):
            first_layer = self.model.model.layers[0]
            first_layer.input_layernorm = DummyInputLayerNorm(
                weight=first_layer.input_layernorm.weight)

        if getattr(inner_config, "skip_output_norm", True):
            self.model.model.norm = DummyOutputNorm()

        # Optional separate norms for the token embedding (enorm) and the
        # previous hidden state (hnorm); only created when requested.
        self.add_para_norm = bool(getattr(inner_config, "add_para_norm",
                                          False))
        if self.add_para_norm:
            self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
            self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.orig_vocab_size = config.vocab_size
        self.truncated_vocab_size = config.truncated_vocab_size
        self.unpadded_vocab_size = self.truncated_vocab_size

        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=self.truncated_vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE,
        )

        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                self.truncated_vocab_size,
                                                logit_scale)

        # Token map is an idx to token mapping to reduce the vocab size for
        # the draft model. Using smaller vocab size for draft, containing
        # only most frequent tokens reduces the speculation overhead. This
        # doesn't affect the acceptance rate much and thus gives more speed
        # -up. By default, this is disabled and is only used if the EAGLE
        # checkpoint file has token_map tensor.
        self.token_map = None

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed ``input_ids`` using the wrapped model's embedding lookup."""
        return self.model.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the draft model on fused token embeddings and hidden states.

        Args:
            input_ids: Token ids; only used to compute embeddings when
                ``inputs_embeds`` is not given.
            positions: Position of each row; rows whose position is 0 have
                their fused input zeroed.
            previous_hidden_states: Hidden states to fuse with the
                embeddings. Replaced by zeros when empty or when its first
                dimension does not match that of the embeddings.
            intermediate_tensors: Forwarded unchanged to the wrapped model.
            inputs_embeds: Optional precomputed embeddings.

        Returns:
            Whatever the wrapped decoder (``self.model.model``) returns for
            the fused inputs.
        """
        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings(input_ids)

        # Handle both empty previous_hidden_states
        # and mismatched batch size
        batch_size = inputs_embeds.size(0)
        prev_rows = previous_hidden_states.size(0)
        if prev_rows == 0 or prev_rows != batch_size:
            # Create zero tensor with matching batch size. The dtype must
            # follow inputs_embeds: with the default dtype the concatenation
            # below would be promoted and no longer match self.fc's weights
            # for half-precision models.
            previous_hidden_states = torch.zeros(
                batch_size,
                self.config.model.hidden_size,
                dtype=inputs_embeds.dtype,
                device=inputs_embeds.device,
            )

        if self.add_para_norm:
            fused = torch.cat([
                self.enorm(inputs_embeds),
                self.hnorm(previous_hidden_states),
            ],
                              dim=-1)
        else:
            fused = torch.cat([inputs_embeds, previous_hidden_states],
                              dim=-1)

        inputs_embeds = self.fc(fused)

        inputs_embeds[positions == 0] = 0  # masking inputs at position=0

        hidden_states = self.model.model(
            input_ids=None,
            inputs_embeds=inputs_embeds,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
        )
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Compute logits, expanding to the original vocab if truncated.

        Args:
            hidden_states: Hidden states passed to the logits processor.
            sampling_metadata: Passed through to the logits processor.

        Returns:
            The logits from ``self.logits_processor``. When a token map is
            loaded, they are scattered into a tensor whose last dimension is
            ``orig_vocab_size``; every entry not listed in the token map is
            ``-inf``.
        """
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)

        if self.token_map is not None:
            truncated_logits = logits
            logits = torch.full(
                (*truncated_logits.shape[:-1], self.orig_vocab_size),
                -torch.inf,
                dtype=truncated_logits.dtype,
                device=truncated_logits.device,
            )
            logits[..., self.token_map] = truncated_logits

        return logits

    @staticmethod
    def _load_param(param, loaded_weight: torch.Tensor) -> None:
        """Load ``loaded_weight`` into ``param``.

        Uses the parameter's own ``weight_loader`` attribute when present and
        falls back to ``default_weight_loader`` otherwise.
        """
        weight_loader = getattr(param, "weight_loader", default_weight_loader)
        weight_loader(param, loaded_weight)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load EAGLE-specific tensors and hand the rest to the wrapped model.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs from a checkpoint.
        """
        # This implementation is incompatible with https://huggingface.co/yuhuili/EAGLE-LLaMA3-Instruct-8B
        # due to missing lm_head weights and its config being that of a
        # Llama model. Here's a compatible version with the same weights:
        # https://huggingface.co/abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm
        # Also, here's an example script for converting trained EAGLE
        # checkpoint to vLLM compatible version: https://gist.github.com/abhigoyal1997/1e7a4109ccb7704fbc67f625e86b2d6d
        model_weights = {}
        for name, loaded_weight in weights:
            if name == "token_map":
                # Only honoured when the draft vocab is actually truncated.
                if self.config.truncated_vocab_size < self.config.vocab_size:
                    self.token_map = nn.Parameter(loaded_weight,
                                                  requires_grad=False)
            elif name.startswith("fc.weight"):
                self._load_param(self.fc.weight, loaded_weight)
            elif name.startswith("fc.bias"):
                if self.fc.bias is not None:
                    self._load_param(self.fc.bias, loaded_weight)
                else:
                    logger.warning_once("Found bias in the loaded weights but "
                                        "the model config doesn't have bias.")
            elif name.startswith(("enorm.weight", "hnorm.weight")):
                # enorm / hnorm only exist when add_para_norm is enabled, so
                # look them up defensively instead of assuming they are set.
                norm = getattr(self, name.split(".", 1)[0], None)
                if norm is not None:
                    self._load_param(norm.weight, loaded_weight)
                else:
                    logger.warning_once(
                        "Found enorm/hnorm weights in the loaded weights but "
                        "the model config doesn't enable add_para_norm.")
            elif name.startswith(("model.lm_head.", "model.model.")):
                # Drop the extra leading "model." prefix.
                model_weights[name.removeprefix("model.")] = loaded_weight
            elif name.startswith(("lm_head.", "model.")):
                model_weights[name] = loaded_weight
            else:
                model_weights[f"model.{name}"] = loaded_weight

        # The LM head is handled after the loop so that token_map is known
        # regardless of the order in which the checkpoint yields tensors.
        if "lm_head.weight" in model_weights:
            lm_head_weight = model_weights.pop("lm_head.weight")

            if self.token_map is not None and \
                    lm_head_weight.shape[0] > self.token_map.shape[0]:
                # Keep only the rows for tokens present in the draft vocab.
                lm_head_weight = lm_head_weight[self.token_map]
        else:
            # NOTE(Shangming): initialize the placeholder for lm_head weight.
            lm_head_weight = torch.zeros(
                self.lm_head.org_vocab_size,
                self.lm_head.embedding_dim,
                dtype=self.config.torch_dtype,
            )

        self._load_param(self.lm_head.weight, lm_head_weight)

        self.model.load_weights(model_weights.items())