# eagle.py
1
2
3
4
5
6
from typing import Iterable, List, Optional, Tuple

import torch
import torch.nn as nn

from vllm.attention.backends.abstract import AttentionMetadata
7
from vllm.config import VllmConfig
8
from vllm.model_executor.layers.logits_processor import LogitsProcessor
9
from vllm.model_executor.layers.sampler import SamplerOutput
10
11
12
13
14
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.sampling_metadata import SamplingMetadata
15
from vllm.sequence import IntermediateTensors
16

17
18
from .utils import maybe_prefix

19

20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
class DummyInputLayerNorm(nn.Module):
    """Identity stand-in for a decoder layer's ``input_layernorm``.

    ``EAGLE.__init__`` installs an instance of this class on the first
    decoder layer of the wrapped model so that the layer's input
    normalization becomes a no-op.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` itself, unmodified (no copy is made)."""
        return x


class DummyOutputNorm(nn.Module):
    """Identity stand-in for the wrapped model's final ``norm`` layer.

    ``EAGLE.__init__`` assigns an instance of this class to
    ``self.model.model.norm`` so that no normalization is applied to the
    decoder output.

    The layer being replaced is called as ``norm(x, residual)``. vLLM's
    fused-residual RMSNorm adds ``residual`` to ``x`` *before* normalizing,
    and callers typically keep only the first element of the returned pair.
    An identity replacement must therefore still perform the residual
    addition; otherwise the residual stream is silently dropped from the
    final hidden states.
    """

    def forward(self, x, residual):
        """Skip normalization while preserving the residual addition.

        Args:
            x: Output of the last decoder layer.
            residual: Residual stream to fold into ``x``, or ``None`` when
                the caller does not carry a separate residual.

        Returns:
            ``x`` itself when ``residual`` is ``None``; otherwise the pair
            ``(x + residual, None)``. The second element is ``None``
            because the residual has already been merged into the first.
        """
        if residual is None:
            return x
        return x + residual, None


35
36
37
38
39
40
class EAGLE(nn.Module):
    """This class implements the EAGLE draft model from the paper: https://arxiv.org/pdf/2401.15077
    Reference implementation: https://github.com/SafeAILab/EAGLE

    Differences from reference implementation:
    1. In reference, LlamaDecoderLayer implementation doesn't have
       input_layernorm for 1st decoder layer (https://github.com/SafeAILab/EAGLE/blob/7d065d084443fbfd386f88839efd7193c12be869/eagle/model/cnets.py#L427).
       Following this approach, our implementation also disables
       the input_layernorm for the first decoder layer.
    2. We allow any decoder layer to be used in EAGLE whereas in reference
       decoder layer is fixed to be LlamaDecoderLayer.
    3. We have an optional token_map which reduces draft vocab to most
       frequently used tokens to give some additional speed-up by reducing
       sampling overhead. This is disabled unless the checkpoint file has
       explicit token_map tensor and config has an optional attribute
       truncated_vocab_size < vocab_size. To use this technique, one has to find
       the top-k most frequent tokens in target dataset and add that as a tensor
       in the draft checkpoint (using key token_map). Also, the draft config
       needs to have truncated_vocab_size (=k) as an attribute."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the wrapped decoder, the fusion layer and the LM head.

        Args:
            vllm_config: Engine configuration. Its
                ``model_config.hf_config`` is used as the draft config;
                ``config.model`` supplies the wrapped model's
                ``architectures`` and ``hidden_size``.
            prefix: Name prefix for the wrapped model, combined with
                ``"model"`` through ``maybe_prefix``.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.config = config

        # Resolve and instantiate the decoder class named by the inner
        # model config.
        architectures = getattr(self.config.model, "architectures", [])
        model_cls, _ = ModelRegistry.resolve_model_cls(architectures)

        self.model = model_cls(vllm_config=vllm_config,
                               prefix=maybe_prefix(prefix, "model"))

        # Projects the concatenation [token embedding, previous hidden
        # state] (2 * hidden_size) back down to hidden_size.
        self.fc = nn.Linear(config.model.hidden_size * 2,
                            config.model.hidden_size,
                            bias=getattr(self.config, "eagle_fc_bias", False))

        # Modify layer normalization and residual connections as suggested
        # in the EAGLE framework: https://github.com/SafeAILab/EAGLE
        self.model.model.layers[0].input_layernorm = DummyInputLayerNorm()
        self.model.model.norm = DummyOutputNorm()

        self.orig_vocab_size = config.vocab_size
        self.truncated_vocab_size = config.truncated_vocab_size
        self.unpadded_vocab_size = self.truncated_vocab_size

        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=self.truncated_vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE,
        )

        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                self.truncated_vocab_size,
                                                logit_scale)

        # Token map is a idx to token mapping to reduce the vocab size for
        # the draft model. Using smaller vocab size for draft, containing
        # only most frequent tokens reduces the speculation overhead. This
        # doesn't affect the acceptance rate much and thus gives more speed
        # -up. By default, this is disabled and is only used if the EAGLE
        # checkpoint file has token_map tensor.
        self.token_map = None

    @property
    def sampler(self):
        """The sampler owned by the wrapped model."""
        return self.model.sampler

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed ``input_ids`` with the wrapped model's embedding layer."""
        return self.model.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        previous_hidden_states: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the draft model on fused token and hidden-state inputs.

        Args:
            input_ids: Token ids; only used to compute embeddings when
                ``inputs_embeds`` is not supplied.
            positions: Token positions. Entries equal to 0 select the rows
                of the fused input that are zeroed out.
            kv_caches: Forwarded unchanged to the wrapped decoder.
            attn_metadata: Forwarded unchanged to the wrapped decoder.
            previous_hidden_states: Hidden states concatenated with the
                token embeddings along the last dimension.
            intermediate_tensors: Forwarded unchanged to the wrapped
                decoder.
            inputs_embeds: Optional precomputed token embeddings.

        Returns:
            The output of the wrapped decoder (``self.model.model``).
        """
        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings(input_ids)

        # Fuse token embeddings with the previous hidden states and project
        # the result back to hidden_size.
        inputs_embeds = self.fc(
            torch.cat([inputs_embeds, previous_hidden_states], dim=-1))

        inputs_embeds[positions == 0] = 0  # masking inputs at position=0

        hidden_states = self.model.model(
            input_ids=None,
            inputs_embeds=inputs_embeds,
            positions=positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
            intermediate_tensors=intermediate_tensors,
        )
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Compute logits, expanding them to the original vocab if needed.

        Without a ``token_map`` the output of ``self.logits_processor`` is
        returned as is. With a ``token_map`` the logits are scattered into
        a tensor whose last dimension is ``self.orig_vocab_size``; every
        entry not named by ``token_map`` is ``-inf``.
        """
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)

        if self.token_map is None:
            return logits

        # Allocate the -inf background directly rather than multiplying a
        # tensor of ones by -inf.
        full_logits = torch.full(
            (*logits.shape[:-1], self.orig_vocab_size),
            -torch.inf,
            device=logits.device,
            dtype=logits.dtype)
        full_logits[..., self.token_map] = logits

        return full_logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits`` with the wrapped sampler."""
        return self.sampler(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load an EAGLE checkpoint into this module.

        ``token_map``, ``fc.*`` and ``lm_head.weight`` are loaded here;
        every other tensor is renamed as needed and handed to the wrapped
        model's ``load_weights``.

        Args:
            weights: ``(name, tensor)`` pairs from the checkpoint.

        Raises:
            ValueError: If the checkpoint contains ``fc.bias`` but ``fc``
                was built without a bias.
            KeyError: If the checkpoint provides no ``lm_head.weight``.
        """
        # This implementation is incompatible with https://huggingface.co/yuhuili/EAGLE-LLaMA3-Instruct-8B
        # due to missing lm_head weights and its config being that of a
        # Llama model. Here's a compatible version with the same weights:
        # https://huggingface.co/abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm
        # Also, here's an example script for converting trained EAGLE
        # checkpoint to vLLM compatible version: https://gist.github.com/abhigoyal1997/1e7a4109ccb7704fbc67f625e86b2d6d
        model_weights = {}
        for name, loaded_weight in weights:
            if name == "token_map":
                # The token map is only honoured when the draft vocab is
                # actually truncated; otherwise it is ignored.
                if self.config.truncated_vocab_size < self.config.vocab_size:
                    self.token_map = nn.Parameter(loaded_weight,
                                                  requires_grad=False)
            elif name.startswith("fc.weight"):
                weight_loader = getattr(self.fc.weight, "weight_loader",
                                        default_weight_loader)
                weight_loader(self.fc.weight, loaded_weight)
            elif name.startswith("fc.bias"):
                if self.fc.bias is None:
                    raise ValueError("Found bias in the loaded weights "
                                     "but the model config doesn't have bias")
                weight_loader = getattr(self.fc.bias, "weight_loader",
                                        default_weight_loader)
                weight_loader(self.fc.bias, loaded_weight)
            elif name.startswith(("model.lm_head.", "model.model.")):
                # Strip the outer "model." so that names are relative to
                # the wrapped model.
                model_weights[name.split("model.", 1)[-1]] = loaded_weight
            elif name.startswith(("lm_head.", "model.")):
                model_weights[name] = loaded_weight
            else:
                model_weights[f"model.{name}"] = loaded_weight

        try:
            lm_head_weight = model_weights.pop("lm_head.weight")
        except KeyError:
            # Same exception type as before, but with an actionable message.
            raise KeyError(
                "EAGLE checkpoint does not contain 'lm_head.weight'; "
                "convert the checkpoint to the vLLM-compatible format "
                "that includes the lm_head weights") from None

        # When a token map is active and the checkpoint head covers more
        # rows than the map, keep only the rows of the mapped tokens.
        if (self.token_map is not None
                and lm_head_weight.shape[0] > self.token_map.shape[0]):
            lm_head_weight = lm_head_weight[self.token_map]

        weight_loader = getattr(self.lm_head.weight, "weight_loader",
                                default_weight_loader)
        weight_loader(self.lm_head.weight, lm_head_weight)

        self.model.load_weights(model_weights.items())