llama_eagle3.py 12 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
from collections.abc import Iterable
from typing import Optional
6
7
8
9
10

import torch
import torch.nn as nn
from transformers import LlamaConfig

11
from vllm.config import VllmConfig, get_current_vllm_config
12
13
14
15
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import QKVParallelLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
16
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
17
from vllm.model_executor.layers.vocab_parallel_embedding import (
18
19
20
21
    DEFAULT_VOCAB_PADDING_SIZE,
    ParallelLMHead,
    VocabParallelEmbedding,
)
22
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
23
from vllm.model_executor.models.llama import LlamaDecoderLayer, LlamaForCausalLM
24
from vllm.multimodal.inputs import NestedTensors
25
26
27
28
29
30
31

from .utils import AutoWeightsLoader, maybe_prefix

# Module-level logger for this file, obtained from vLLM's ``init_logger`` helper.
logger = init_logger(__name__)


class LlamaDecoderLayer(LlamaDecoderLayer):
    """EAGLE-3 draft decoder layer built on vLLM's Llama decoder layer.

    Compared with the parent layer:

    * The QKV projection is rebuilt. On the first layer (``layer_idx == 0``)
      its input width is ``2 * hidden_size``, because that layer attends over
      the concatenation of normalized token embeddings and normalized hidden
      states. Every later layer uses the plain ``hidden_size``.
    * An extra ``hidden_norm`` RMSNorm is applied to the incoming hidden states
      on the first layer. The config flag ``norm_before_residual`` selects
      whether the residual is taken after or before that norm.
    * The rebuilt QKV projection uses the drafter's quantization config rather
      than the verifier's.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        config: Optional[LlamaConfig] = None,
        layer_idx: int = 0,
    ) -> None:
        super().__init__(vllm_config, prefix=prefix, config=config)

        # Fall back to the model's HF config when no explicit config is given.
        layer_config = config if config else vllm_config.model_config.hf_config
        drafter_quant_config = self.get_quant_config(vllm_config)

        # Layer 0 consumes [embeds ; hidden_states] concatenated on the last
        # dim, so its projection input is twice as wide as for deeper layers.
        is_first_layer = layer_idx == 0
        qkv_in_features = self.hidden_size * (2 if is_first_layer else 1)

        attn = self.self_attn
        # Replace the parent's QKV projection with one of the right input width.
        # NOTE(review): the prefix is derived from the decoder-layer prefix, not
        # from an attention-module prefix; confirm this matches the layer names
        # a drafter quantization config expects.
        attn.qkv_proj = QKVParallelLinear(
            qkv_in_features,
            attn.head_dim,
            attn.total_num_heads,
            attn.total_num_kv_heads,
            bias=False,
            quant_config=drafter_quant_config,
            prefix=maybe_prefix(prefix, "qkv_proj"),
        )

        self.hidden_norm = RMSNorm(
            layer_config.hidden_size, eps=layer_config.rms_norm_eps
        )
        self.layer_idx = layer_idx

        # Decide once, at construction time, which residual/norm ordering the
        # first layer applies in forward().
        self._residual_norm = (
            self._norm_before_residual
            if getattr(layer_config, "norm_before_residual", False)
            else self._norm_after_residual
        )

    def get_quant_config(self, vllm_config: VllmConfig) -> Optional[QuantizationConfig]:
        """Return the drafter's quantization config (not the verifier's).

        Returns ``None`` when the speculative config has no draft model config.
        """
        speculative_config = vllm_config.speculative_config
        drafter_model_config = speculative_config.draft_model_config
        load_config = vllm_config.load_config
        if not drafter_model_config:
            return None
        return VllmConfig.get_quantization_config(drafter_model_config, load_config)

    def _norm_before_residual(
        self, hidden_states: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Normalize first; the residual is the *normalized* tensor."""
        normed = self.hidden_norm(hidden_states)
        return normed, normed

    def _norm_after_residual(
        self, hidden_states: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Keep the raw input as the residual and return its normalized copy."""
        return self.hidden_norm(hidden_states), hidden_states

    def forward(
        self,
        positions: torch.Tensor,
        embeds: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run one draft decoder layer.

        Args:
            positions: Token positions forwarded to self-attention.
            embeds: Token embeddings; only consumed by the first layer.
            hidden_states: Output of the previous layer, or, for the first
                layer, the hidden states passed into the model.
            residual: Residual stream from the previous layer. The first layer
                ignores it and derives its own residual.

        Returns:
            ``(hidden_states, residual)`` after attention and the MLP.
        """
        if self.layer_idx != 0:
            # Deeper layers: normalize hidden_states together with the
            # incoming residual.
            attn_input, residual = self.input_layernorm(hidden_states, residual)
        else:
            normed_embeds = self.input_layernorm(embeds)
            normed_hidden, residual = self._residual_norm(hidden_states=hidden_states)
            attn_input = torch.cat([normed_embeds, normed_hidden], dim=-1)

        attn_output = self.self_attn(positions=positions, hidden_states=attn_input)
        mlp_input, residual = self.post_attention_layernorm(attn_output, residual)
        return self.mlp(mlp_input), residual


class LlamaModel(nn.Module):
    """Backbone of the EAGLE-3 Llama drafter.

    Holds the token embedding, the stack of draft decoder layers, the ``fc``
    projection, and the final RMSNorm. ``fc`` takes an input three times a
    hidden size wide (presumably three concatenated auxiliary hidden states
    from the verifier; confirm against the caller).
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        start_layer_id: int = 0,
        prefix: str = "",
    ) -> None:
        super().__init__()
        draft_hf_config = vllm_config.speculative_config.draft_model_config.hf_config
        self.config = draft_hf_config
        self.vocab_size = draft_hf_config.vocab_size

        # The decoder layers receive the *current* global vLLM config rather
        # than the ``vllm_config`` argument.
        layer_vllm_config = get_current_vllm_config()

        self.embed_tokens = VocabParallelEmbedding(
            draft_hf_config.vocab_size,
            draft_hf_config.hidden_size,
            prefix=maybe_prefix(prefix, "embed_tokens"),
        )

        decoder_layers = []
        for idx in range(draft_hf_config.num_hidden_layers):
            # Layer prefixes are offset by start_layer_id; layer_idx itself is
            # the local index, so index 0 is the layer that consumes embeds.
            decoder_layers.append(
                LlamaDecoderLayer(
                    layer_vllm_config,
                    prefix=maybe_prefix(prefix, f"layers.{start_layer_id + idx}"),
                    config=draft_hf_config,
                    layer_idx=idx,
                )
            )
        self.layers = nn.ModuleList(decoder_layers)

        # fc input width: 3 x target_hidden_size when the draft config defines
        # it, otherwise 3 x the drafter's own hidden_size.
        fused_width = getattr(
            draft_hf_config, "target_hidden_size", draft_hf_config.hidden_size
        )
        self.fc = torch.nn.Linear(
            fused_width * 3, draft_hf_config.hidden_size, bias=False
        )
        self.norm = RMSNorm(
            draft_hf_config.hidden_size,
            eps=draft_hf_config.rms_norm_eps,
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        input_embeds: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run all draft layers and the final norm.

        Embeddings are computed from ``input_ids`` unless ``input_embeds`` is
        supplied. Returns ``(normed_hidden_states, prenorm_hidden_states)`` as
        produced by the final RMSNorm called with the last residual.
        """
        embeds = (
            input_embeds
            if input_embeds is not None
            else self.get_input_embeddings(input_ids)
        )
        # The first layer concatenates embeds with hidden_states, so their last
        # dimensions must agree.
        assert hidden_states.shape[-1] == embeds.shape[-1]

        residual = None
        for decoder_layer in self.layers:
            hidden_states, residual = decoder_layer(
                positions=positions,
                embeds=embeds,
                hidden_states=hidden_states,
                residual=residual,
            )
        normed, prenorm = self.norm(hidden_states, residual)
        return normed, prenorm

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Names with a ``midlayer.`` prefix are mapped onto ``layers.0.``.
        Separate q/k/v and gate/up shards are routed into the fused
        ``qkv_proj`` / ``gate_up_proj`` parameters with their shard id.
        Returns the set of parameter names that were loaded.
        """
        # (fused parameter suffix, checkpoint shard suffix, shard id)
        fused_shards = (
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        )
        named_params = dict(self.named_parameters())
        loaded: set[str] = set()
        for ckpt_name, tensor in weights:
            name = ckpt_name.replace("midlayer.", "layers.0.")
            # First matching shard rule wins, mirroring a for/break scan.
            shard = next(
                (rule for rule in fused_shards if rule[1] in name),
                None,
            )
            if shard is None:
                param = named_params[name]
                loader_fn = getattr(param, "weight_loader", default_weight_loader)
                loader_fn(param, tensor)
            else:
                fused_suffix, shard_suffix, shard_id = shard
                name = name.replace(shard_suffix, fused_suffix)
                param = named_params[name]
                param.weight_loader(param, tensor, shard_id)
            loaded.add(name)
        return loaded


class Eagle3LlamaForCausalLM(LlamaForCausalLM):
    """EAGLE-3 draft model for Llama-style verifiers.

    The draft LM head may use a reduced vocabulary (``draft_vocab_size``).
    ``draft_id_to_target_id`` holds, per draft token, an offset into the target
    vocabulary (target id = draft id + offset). :meth:`compute_logits` uses it
    to scatter draft logits into a ``vocab_size``-wide tensor.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # Initialize only nn.Module; LlamaForCausalLM.__init__ is not called and
        # every submodule is built here from the draft model's HF config.
        nn.Module.__init__(self)
        self.config = vllm_config.speculative_config.draft_model_config.hf_config

        # Ensure draft_vocab_size is set, defaulting to the draft config's
        # vocab_size when it is absent.
        if getattr(self.config, "draft_vocab_size", None) is None:
            base_vocab_size = getattr(self.config, "vocab_size", None)
            self.config.draft_vocab_size = base_vocab_size

        target_layer_num = vllm_config.model_config.get_num_layers(
            vllm_config.parallel_config
        )

        # Store target layer count in draft config for
        # proper layer_types indexing in draft models
        self.config.target_layer_count = target_layer_num
        # Draft layer prefixes are numbered starting at target_layer_num.
        self.model = LlamaModel(
            vllm_config=vllm_config, prefix="model", start_layer_id=target_layer_num
        )

        logit_scale = getattr(self.config, "logit_scale", 1.0)
        # The LM head and logits processor both operate over the (possibly
        # reduced) draft vocabulary.
        self.lm_head = ParallelLMHead(
            self.config.draft_vocab_size,
            self.config.hidden_size,
            org_num_embeddings=self.config.draft_vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        self.logits_processor = LogitsProcessor(
            self.config.draft_vocab_size, scale=logit_scale
        )
        # Per-draft-token offset into the target vocabulary, filled from the
        # checkpoint's "d2t" tensor. Zero-initialized, i.e. an identity
        # mapping, if no "d2t" tensor is loaded.
        self.draft_id_to_target_id = nn.Parameter(
            torch.zeros(self.config.draft_vocab_size, dtype=torch.long),
            requires_grad=False,
        )

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[NestedTensors] = None,
        is_multimodal: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` with the drafter's embedding table.

        The multimodal arguments are accepted for interface compatibility and
        are ignored.
        """
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Delegate to the draft backbone; returns ``(normed, prenorm)``."""
        return self.model(input_ids, positions, hidden_states, inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        """Compute draft logits and expand them to the target vocabulary.

        Returns:
            A 2-D tensor with ``vocab_size`` columns. Target ids that no draft
            token maps to are ``-inf``. Returns ``None`` if the logits
            processor produced no logits.
        """
        logits = self.logits_processor(self.lm_head, hidden_states)
        if logits is None:
            # The return type is Optional. Propagate a None from the logits
            # processor (presumably ranks without gathered logits; confirm)
            # instead of failing with AttributeError on ``logits.device`` below.
            return None

        if self.draft_id_to_target_id is None:
            # Not reachable with __init__ as written, because the mapping is
            # always a Parameter. Kept in case the mapping is cleared externally.
            assert logits.shape[1] == self.config.vocab_size, (
                "Expected logits to have shape "
                f"(*, {self.config.vocab_size}), but got {logits.shape}"
            )
            return logits

        # d2t stores offsets: target_id = draft_id + d2t[draft_id].
        base = torch.arange(self.config.draft_vocab_size, device=logits.device)
        targets = base + self.draft_id_to_target_id
        logits_new = logits.new_full(
            (logits.shape[0], self.config.vocab_size),
            float("-inf"),
        )
        logits_new[:, targets] = logits
        return logits_new

    def combine_hidden_states(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project the verifier's auxiliary hidden states through ``model.fc``."""
        # combine multiple auxiliary hidden states returned by eagle3
        return self.model.fc(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load an EAGLE-3 draft checkpoint.

        Checkpoint names are remapped before delegating to AutoWeightsLoader:

        * ``t2d`` tensors are dropped.
        * ``d2t`` is loaded into ``draft_id_to_target_id``.
        * ``lm_head`` weights keep their name.
        * Everything else is prefixed with ``model.``.

        ``draft_id_to_target_id`` and ``embed_tokens`` are skipped when the
        checkpoint does not provide them.
        """
        model_weights = {}
        includes_draft_id_mapping = False
        includes_embed_tokens = False
        for name, loaded_weight in weights:
            if "t2d" in name:
                continue
            if "d2t" in name:
                name = name.replace("d2t", "draft_id_to_target_id")
                includes_draft_id_mapping = True
            elif "lm_head" not in name:
                name = "model." + name
            if "embed_tokens" in name:
                includes_embed_tokens = True
            model_weights[name] = loaded_weight

        skip_substrs = []
        if not includes_draft_id_mapping:
            skip_substrs.append("draft_id_to_target_id")
            if self.config.draft_vocab_size != self.config.vocab_size:
                # Without "d2t" the mapping keeps its zero initialization, so
                # draft id i is scored as target id i. That is generally only
                # meaningful when both vocabularies coincide, so report it
                # instead of silently producing misaligned logits.
                logger.warning(
                    "EAGLE-3 checkpoint has no 'd2t' mapping but "
                    "draft_vocab_size (%s) != vocab_size (%s); draft token ids "
                    "will be used as target token ids unchanged.",
                    self.config.draft_vocab_size,
                    self.config.vocab_size,
                )
        if not includes_embed_tokens:
            # Leave embed_tokens untouched (presumably provided from elsewhere,
            # e.g. shared with the target model; confirm with the caller).
            skip_substrs.append("embed_tokens")

        loader = AutoWeightsLoader(
            self,
            skip_prefixes=None,
            skip_substrs=skip_substrs,
        )
        loader.load_weights(model_weights.items())