llama_eagle3.py 11.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
from collections.abc import Iterable
from typing import Optional
6
7
8
9
10

import torch
import torch.nn as nn
from transformers import LlamaConfig

11
from vllm.config import VllmConfig, get_current_vllm_config
12
13
14
15
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import QKVParallelLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
16
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
17
from vllm.model_executor.layers.vocab_parallel_embedding import (
18
19
20
21
    DEFAULT_VOCAB_PADDING_SIZE,
    ParallelLMHead,
    VocabParallelEmbedding,
)
22
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
23
from vllm.model_executor.models.llama import LlamaDecoderLayer, LlamaForCausalLM
24
25
26
27
28
29
30

from .utils import AutoWeightsLoader, maybe_prefix

# Module-level logger created through vLLM's ``init_logger``. It is not
# referenced by the code visible in this section of the file.
logger = init_logger(__name__)


class LlamaDecoderLayer(LlamaDecoderLayer):
    """EAGLE-3 draft decoder layer.

    Extends the imported Llama decoder layer so that self-attention runs over
    ``cat([embeds, hidden_states], dim=-1)``. Because of that concatenation the
    QKV projection is replaced with one taking ``2 * hidden_size`` inputs.

    Note: the class intentionally reuses (and shadows) the base-class name, so
    later references in this module resolve to this EAGLE-3 variant.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        config: Optional[LlamaConfig] = None,
    ) -> None:
        """Build the layer.

        Args:
            vllm_config: Global vLLM config. Its speculative config supplies the
                drafter's quantization settings (see ``get_quant_config``).
            prefix: Parameter-name prefix for this layer.
            config: HF config of the drafter; when omitted,
                ``vllm_config.model_config.hf_config`` is used.
        """
        super().__init__(vllm_config, prefix=prefix, config=config)

        config = config or vllm_config.model_config.hf_config
        quant_config = self.get_quant_config(vllm_config)

        # Override the inherited QKV projection: its input is the concatenation
        # of the normalized embeddings and hidden states, i.e. twice the hidden
        # size (``self.hidden_size`` is set by the base class).
        self.self_attn.qkv_proj = QKVParallelLinear(
            2 * self.hidden_size,
            self.self_attn.head_dim,
            self.self_attn.total_num_heads,
            self.self_attn.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "qkv_proj"),
        )

        # Norm for the incoming hidden states; the embeddings go through the
        # inherited ``input_layernorm`` instead.
        self.hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # ``norm_before_residual`` decides whether the residual branch carries
        # the normalized or the raw incoming hidden states.
        if getattr(config, "norm_before_residual", False):
            self._residual_norm = self._norm_before_residual
        else:
            self._residual_norm = self._norm_after_residual

    def get_quant_config(self, vllm_config: VllmConfig) -> Optional[QuantizationConfig]:
        """Use drafter's quantization config instead of verifier's.

        Returns ``None`` when the speculative config has no draft model config.
        """
        draft_model_config = vllm_config.speculative_config.draft_model_config
        draft_load_config = vllm_config.load_config

        return (
            VllmConfig.get_quantization_config(draft_model_config, draft_load_config)
            if draft_model_config
            else None
        )

    def _norm_before_residual(
        self, hidden_states: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Normalize first; the residual is the *normalized* hidden states."""
        hidden_states = self.hidden_norm(hidden_states)
        residual = hidden_states
        return hidden_states, residual

    def _norm_after_residual(
        self, hidden_states: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Keep the *raw* hidden states as residual; normalize the copy used
        as attention input."""
        residual = hidden_states
        hidden_states = self.hidden_norm(hidden_states)
        return hidden_states, residual

    def forward(
        self,
        positions: torch.Tensor,
        embeds: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run one EAGLE-3 draft layer step.

        Args:
            positions: Token positions, forwarded to self-attention.
            embeds: Token embeddings for the draft inputs.
            hidden_states: Incoming hidden states. They are concatenated with
                ``embeds`` along the last dim, so the combined width must match
                ``qkv_proj``'s ``2 * hidden_size`` input.
            residual: Ignored. The residual is always recomputed from
                ``hidden_states``; the parameter only keeps the call signature.

        Returns:
            ``(mlp_output, residual)``, where ``residual`` is the second value
            returned by ``post_attention_layernorm``.
        """
        embeds = self.input_layernorm(embeds)

        # Overwrites the ``residual`` argument (see docstring).
        hidden_states, residual = self._residual_norm(hidden_states=hidden_states)

        hidden_states = torch.cat([embeds, hidden_states], dim=-1)
        # Self Attention
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )

        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)

        # Fully Connected
        hidden_states = self.mlp(hidden_states)

        return hidden_states, residual


class LlamaModel(nn.Module):
    """Single-layer EAGLE-3 draft model body.

    Holds the draft token embedding, one EAGLE-3 ``LlamaDecoderLayer`` from
    this module, the ``fc`` projection that fuses the target model's auxiliary
    hidden states, and the final norm.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        start_layer_id: int = 0,
        prefix: str = "",
    ) -> None:
        """Build the draft body.

        Args:
            vllm_config: Global vLLM config; the drafter's HF config is read
                from ``speculative_config.draft_model_config``.
            start_layer_id: Index used in the decoder layer's name prefix
                (``layers.{start_layer_id}``). NOTE(review): presumably chosen
                past the target's layers so names do not collide — confirm.
            prefix: Parameter-name prefix for submodules.
        """
        super().__init__()
        self.config = vllm_config.speculative_config.draft_model_config.hf_config
        self.vocab_size = self.config.vocab_size

        # The decoder layer is built from the ambient config returned by
        # ``get_current_vllm_config()``, not from the ``vllm_config`` argument.
        current_vllm_config = get_current_vllm_config()

        self.embed_tokens = VocabParallelEmbedding(
            self.config.vocab_size,
            self.config.hidden_size,
            prefix=maybe_prefix(prefix, "embed_tokens"),
        )

        self.layers = nn.ModuleList(
            [
                LlamaDecoderLayer(
                    current_vllm_config,
                    prefix=maybe_prefix(prefix, f"layers.{start_layer_id}"),
                    config=self.config,
                )
            ]
        )

        # ``fc`` maps 3 * <target hidden size> features (the concatenated
        # auxiliary hidden states) to the draft hidden size. Checkpoints whose
        # target width differs from the draft width declare it explicitly as
        # ``target_hidden_size``; otherwise the draft width is used.
        target_hidden_size = getattr(
            self.config, "target_hidden_size", self.config.hidden_size
        )
        self.fc = torch.nn.Linear(
            target_hidden_size * 3, self.config.hidden_size, bias=False
        )

        self.norm = RMSNorm(
            self.config.hidden_size,
            eps=self.config.rms_norm_eps,
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed ``input_ids`` with the draft embedding table."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        input_embeds: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the draft body for one step.

        Args:
            input_ids: Token ids; only used when ``input_embeds`` is ``None``.
            positions: Token positions, forwarded to the decoder layer.
            hidden_states: Hidden states fed to the decoder layer; last dim must
                equal the embedding width.
            input_embeds: Precomputed embeddings, if available.

        Returns:
            ``(normed, prenorm)``: both values returned by the final ``norm``.
        """
        if input_embeds is None:
            input_embeds = self.get_input_embeddings(input_ids)
        # The decoder layer concatenates these along the last dim, so the
        # widths must agree.
        assert hidden_states.shape[-1] == input_embeds.shape[-1]

        # The layer ignores its ``residual`` argument; pass ``None``.
        hidden_states, residual = self.layers[0](
            positions,
            input_embeds,
            hidden_states,
            None,
        )

        hidden_states, hidden_prenorm = self.norm(hidden_states, residual)
        return hidden_states, hidden_prenorm

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load draft-body weights, fusing q/k/v and gate/up shards.

        EAGLE-3 checkpoints name the single decoder layer ``midlayer.``; that
        prefix is rewritten to ``layers.0.`` to match this module.

        Returns:
            The (remapped) names of all parameters that were loaded.

        Raises:
            KeyError: if a (remapped) weight name has no matching parameter.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            # No-op when the prefix is absent.
            name = name.replace("midlayer.", "layers.0.")

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # Shard of a fused parameter: load into its slot.
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Plain (unfused) parameter.
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


class Eagle3LlamaForCausalLM(LlamaForCausalLM):
    """EAGLE-3 drafter for Llama-family target models.

    Wraps the single-layer draft ``LlamaModel`` with an LM head over the draft
    vocabulary. Draft-vocabulary logits are mapped back into the target
    vocabulary through ``draft_id_to_target_id``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # Deliberately bypass LlamaForCausalLM.__init__: every submodule is
        # built here from the speculative (draft) model config.
        nn.Module.__init__(self)
        self.config = vllm_config.speculative_config.draft_model_config.hf_config

        # Ensure draft_vocab_size is set
        # default to the base vocab size when absent
        if getattr(self.config, "draft_vocab_size", None) is None:
            base_vocab_size = getattr(self.config, "vocab_size", None)
            self.config.draft_vocab_size = base_vocab_size

        target_layer_num = vllm_config.model_config.get_num_layers(
            vllm_config.parallel_config
        )

        # Store target layer count in draft config for
        # proper layer_types indexing in draft models
        self.config.target_layer_count = target_layer_num

        self.model = LlamaModel(
            vllm_config=vllm_config, prefix="model", start_layer_id=target_layer_num
        )

        logit_scale = getattr(self.config, "logit_scale", 1.0)
        self.lm_head = ParallelLMHead(
            self.config.draft_vocab_size,
            self.config.hidden_size,
            org_num_embeddings=self.config.draft_vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        self.logits_processor = LogitsProcessor(
            self.config.draft_vocab_size, scale=logit_scale
        )
        # Per-draft-token offset into the target vocabulary:
        # target_id = draft_id + draft_id_to_target_id[draft_id].
        # Zero-initialized; overwritten by the checkpoint's ``d2t`` tensor when
        # present (see ``load_weights``).
        self.draft_id_to_target_id = nn.Parameter(
            torch.zeros(self.config.draft_vocab_size, dtype=torch.long),
            requires_grad=False,
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed ``input_ids`` with the draft model's embedding table."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Delegate to the draft body; returns its ``(normed, prenorm)`` pair."""
        return self.model(input_ids, positions, hidden_states, inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        """Compute logits over the *target* vocabulary.

        The LM head yields ``draft_vocab_size`` logits per row. Those are
        scattered into a ``vocab_size``-wide tensor at the mapped target ids;
        target ids with no draft counterpart get ``-inf``.

        Returns:
            A ``(rows, vocab_size)`` tensor, or ``None`` when the logits
            processor produced no logits.
        """
        logits = self.logits_processor(self.lm_head, hidden_states)
        if logits is None:
            # Nothing to remap. NOTE(review): presumably happens on ranks that
            # do not receive gathered logits — confirm.
            return None

        # ``__init__`` always registers the mapping, so this branch is only
        # reachable if the attribute is cleared after construction.
        if self.draft_id_to_target_id is None:
            assert logits.shape[1] == self.config.vocab_size, (
                "Expected logits to have shape "
                f"(*, {self.config.vocab_size}), but got {logits.shape}"
            )
            return logits

        base = torch.arange(self.config.draft_vocab_size, device=logits.device)
        targets = base + self.draft_id_to_target_id
        logits_new = logits.new_full(
            (
                logits.shape[0],
                self.config.vocab_size,
            ),
            float("-inf"),
        )
        logits_new[:, targets] = logits
        return logits_new

    def combine_hidden_states(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project concatenated auxiliary hidden states to the draft width."""
        # combine multiple auxiliary hidden states returned by eagle3
        return self.model.fc(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load an EAGLE-3 checkpoint.

        Name handling:
          * names containing ``t2d`` are dropped;
          * ``d2t`` is loaded into ``draft_id_to_target_id``;
          * every other non-``lm_head`` weight is re-rooted under ``model.``.

        ``draft_id_to_target_id`` and ``embed_tokens`` are skipped (left as
        initialized) when the checkpoint does not contain them.

        NOTE(review): unlike ``LlamaModel.load_weights``, this returns ``None``
        rather than the set of loaded names. Callers may rely on that to avoid
        a strict "all weights loaded" check for skipped/shared weights —
        confirm before changing.
        """
        model_weights = {}
        includes_draft_id_mapping = False
        includes_embed_tokens = False
        for name, loaded_weight in weights:
            if "t2d" in name:
                continue
            if "d2t" in name:
                name = name.replace("d2t", "draft_id_to_target_id")
                includes_draft_id_mapping = True
            elif "lm_head" not in name:
                name = "model." + name
            if "embed_tokens" in name:
                includes_embed_tokens = True
            model_weights[name] = loaded_weight

        skip_substrs = []
        if not includes_draft_id_mapping:
            skip_substrs.append("draft_id_to_target_id")
        if not includes_embed_tokens:
            skip_substrs.append("embed_tokens")

        loader = AutoWeightsLoader(
            self,
            skip_prefixes=None,
            skip_substrs=skip_substrs,
        )
        loader.load_weights(model_weights.items())