# llama_eagle3.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Iterable

import torch
import torch.nn as nn
from transformers import LlamaConfig

from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import QKVParallelLinear, ReplicatedLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
from vllm.model_executor.models.llama import LlamaDecoderLayer, LlamaForCausalLM
from vllm.multimodal.inputs import NestedTensors

from .utils import (
    AutoWeightsLoader,
    get_draft_quant_config,
    maybe_prefix,
    process_eagle_weight,
)

# Module-level logger for this file.
logger = init_logger(__name__)


class LlamaDecoderLayer(LlamaDecoderLayer):
    """EAGLE-3 draft decoder layer, derived from the Llama decoder layer.

    Differences from the base layer, as implemented below:

    * ``qkv_proj`` is rebuilt with the drafter's quantization config and with
      ``config.attention_bias`` (default ``False``) as its bias flag.
    * Layer 0 attends over the concatenation of the normalized token
      embeddings and the normalized incoming hidden states, so its
      ``qkv_proj`` takes ``2 * hidden_size`` inputs. Later layers take
      ``hidden_size``.
    * An extra ``hidden_norm`` normalizes the incoming hidden states on
      layer 0. ``config.norm_before_residual`` chooses whether the residual
      is captured after (True) or before (False) that normalization.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        config: LlamaConfig | None = None,
        layer_idx: int = 0,
    ) -> None:
        """Build the layer.

        Args:
            vllm_config: Global vLLM config. Used for the quantization
                config and as the fallback source of ``config``.
            prefix: Parameter-name prefix for this layer.
            config: Draft model HF config. Defaults to
                ``vllm_config.model_config.hf_config`` when omitted.
            layer_idx: Index of this layer within the draft model. Index 0
                selects the embeds+hidden_states concatenation path.
        """
        super().__init__(vllm_config, prefix=prefix, config=config)

        config = config or vllm_config.model_config.hf_config
        quant_config = self.get_quant_config(vllm_config)

        # First layer uses 2*hidden_size (embeds + hidden_states concatenated)
        # Subsequent layers use hidden_size (only hidden_states, no embeds)
        qkv_input_size = 2 * self.hidden_size if layer_idx == 0 else self.hidden_size

        # Parallel drafting checkpoints may have attention bias enabled
        qkv_bias = getattr(config, "attention_bias", False)

        # Override qkv_proj with correct input size and bias setting
        self.self_attn.qkv_proj = QKVParallelLinear(
            qkv_input_size,
            self.self_attn.head_dim,
            self.self_attn.total_num_heads,
            self.self_attn.total_num_kv_heads,
            bias=qkv_bias,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "qkv_proj"),
        )

        # Created on every layer, but only applied on layer 0 in forward().
        self.hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.layer_idx = layer_idx

        # Pick the residual-capture order once, instead of branching per call.
        if getattr(config, "norm_before_residual", False):
            self._residual_norm = self._norm_before_residual
        else:
            self._residual_norm = self._norm_after_residual

    def get_quant_config(self, vllm_config: VllmConfig) -> QuantizationConfig | None:
        """Use drafter's quantization config instead of verifier's."""
        return get_draft_quant_config(vllm_config)

    def _norm_before_residual(
        self, hidden_states: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Normalize first, then use the *normalized* states as the residual."""
        hidden_states = self.hidden_norm(hidden_states)
        residual = hidden_states
        return hidden_states, residual

    def _norm_after_residual(
        self, hidden_states: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Keep the *raw* states as the residual, then normalize."""
        residual = hidden_states
        hidden_states = self.hidden_norm(hidden_states)
        return hidden_states, residual

    def forward(
        self,
        positions: torch.Tensor,
        embeds: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run one draft decoder layer.

        Args:
            positions: Token positions forwarded to attention.
            embeds: Token embeddings. Only read on layer 0.
            hidden_states: Incoming hidden states.
            residual: Residual stream from the previous layer. Ignored on
                layer 0, which derives its own residual. Later layers expect
                the residual returned by the previous layer.

        Returns:
            ``(hidden_states, residual)`` after attention and MLP.
        """
        if self.layer_idx == 0:
            # First layer: concatenate embeds with hidden_states
            embeds = self.input_layernorm(embeds)
            hidden_states, residual = self._residual_norm(hidden_states=hidden_states)
            hidden_states = torch.cat([embeds, hidden_states], dim=-1)
        else:
            # Subsequent layers: process hidden_states and residuals only
            hidden_states, residual = self.input_layernorm(hidden_states, residual)

        # Self Attention
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )

        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)

        # Fully Connected
        hidden_states = self.mlp(hidden_states)

        return hidden_states, residual


125
@support_torch_compile(
    dynamic_arg_dims={
        "input_ids": 0,
        "positions": -1,
        "hidden_states": 0,
        "input_embeds": 0,
    }
)
class LlamaModel(nn.Module):
    """Backbone of the EAGLE-3 Llama draft model.

    Holds the token embedding, the draft decoder layers and the final norm.
    When auxiliary hidden states are enabled, it also holds ``fc``, with an
    optional ``input_norm`` in front of it. ``fc`` projects the concatenation
    of three target hidden states down to the draft ``hidden_size``.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        start_layer_id: int = 0,
        prefix: str = "",
    ) -> None:
        """Build the draft backbone from the speculative draft model config.

        Args:
            vllm_config: Global vLLM config. The draft HF config is read from
                ``speculative_config.draft_model_config``.
            start_layer_id: Offset added to each draft layer's index when
                forming its parameter prefix (``layers.{idx + offset}``).
            prefix: Parameter-name prefix for this module.
        """
        super().__init__()
        self.config = vllm_config.speculative_config.draft_model_config.hf_config
        self.vocab_size = self.config.vocab_size

        # Get drafter's quantization config
        self.quant_config = get_draft_quant_config(vllm_config)

        # The aux-hidden-state path is on unless eagle_config explicitly
        # sets `use_aux_hidden_state`.
        eagle_config = getattr(self.config, "eagle_config", None)
        if eagle_config is not None and "use_aux_hidden_state" in eagle_config:
            self.use_aux_hidden_state = eagle_config["use_aux_hidden_state"]
        else:
            self.use_aux_hidden_state = True
        self.norm_before_fc = getattr(self.config, "norm_before_fc", False)

        current_vllm_config = get_current_vllm_config()

        self.embed_tokens = VocabParallelEmbedding(
            self.config.vocab_size,
            self.config.hidden_size,
            prefix=maybe_prefix(prefix, "embed_tokens"),
        )

        self.layers = nn.ModuleList(
            [
                LlamaDecoderLayer(
                    current_vllm_config,
                    prefix=maybe_prefix(prefix, f"layers.{layer_idx + start_layer_id}"),
                    config=self.config,
                    layer_idx=layer_idx,
                )
                for layer_idx in range(self.config.num_hidden_layers)
            ]
        )

        if self.use_aux_hidden_state:
            # fc consumes three concatenated target hidden states. Fall back
            # to the draft hidden size when `target_hidden_size` is absent OR
            # explicitly null. The old `hasattr` check crashed on `None * 3`
            # for the latter.
            target_hidden_size = getattr(self.config, "target_hidden_size", None)
            if target_hidden_size is None:
                target_hidden_size = self.config.hidden_size
            fc_input_size = target_hidden_size * 3

            if self.norm_before_fc:
                self.input_norm = RMSNorm(
                    fc_input_size,
                    eps=self.config.rms_norm_eps,
                )
            else:
                self.input_norm = None

            self.fc = ReplicatedLinear(
                input_size=fc_input_size,
                output_size=self.config.hidden_size,
                bias=False,
                params_dtype=vllm_config.model_config.dtype,
                quant_config=self.quant_config,
                prefix=maybe_prefix(prefix, "fc"),
                return_bias=False,
            )

        self.norm = RMSNorm(
            self.config.hidden_size,
            eps=self.config.rms_norm_eps,
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        input_embeds: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run all draft layers followed by the final norm.

        Args:
            input_ids: Token ids. Only used when ``input_embeds`` is None.
            positions: Token positions forwarded to every layer.
            hidden_states: Incoming hidden states. Their last dimension must
                equal the embedding width (asserted below), so any aux-state
                combination must already have happened.
            input_embeds: Optional precomputed token embeddings.

        Returns:
            ``(normed_hidden_states, prenorm_hidden_states)``, as returned by
            ``self.norm(hidden_states, residual)``.
        """
        if input_embeds is None:
            input_embeds = self.embed_input_ids(input_ids)
        assert hidden_states.shape[-1] == input_embeds.shape[-1]

        # Layer 0 creates the residual stream; later layers carry it forward.
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(
                positions=positions,
                embeds=input_embeds,
                hidden_states=hidden_states,
                residual=residual,
            )
        hidden_states, hidden_prenorm = self.norm(hidden_states, residual)
        return hidden_states, hidden_prenorm

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load draft backbone weights.

        Renames ``midlayer.*`` to ``layers.0.*``. Fuses separate q/k/v and
        gate/up checkpoint tensors into the stacked ``qkv_proj`` /
        ``gate_up_proj`` parameters. Routes KV-cache quantization scales
        through the quant config.

        Returns:
            The set of parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            # Checkpoints name the single draft layer "midlayer".
            if "midlayer." in name:
                name = name.replace("midlayer.", "layers.0.")
            # Handle kv cache quantization scales
            if self.quant_config is not None and (
                scale_name := self.quant_config.get_cache_scale(name)
            ):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                # A non-scalar scale tensor contributes only its first element.
                loaded_weight = (
                    loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
                )
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue
            # Remapping the name FP8 kv-scale or zero point.
            if "scale" in name or "zero_point" in name:
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a stacked shard: load the tensor as-is.
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


class Eagle3LlamaForCausalLM(LlamaForCausalLM):
    """EAGLE-3 draft model for Llama-family targets.

    Wraps the draft ``LlamaModel``, an LM head over the draft vocabulary, and
    ``draft_id_to_target_id``. ``compute_logits`` uses that table to scatter
    draft-vocabulary logits into the full target vocabulary.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the draft model from ``vllm_config.speculative_config``."""
        # Deliberately bypasses LlamaForCausalLM.__init__; the draft
        # sub-modules are built directly below.
        nn.Module.__init__(self)
        self.config = vllm_config.speculative_config.draft_model_config.hf_config
        # Ensure draft_vocab_size is set
        # default to the base vocab size when absent
        if getattr(self.config, "draft_vocab_size", None) is None:
            base_vocab_size = getattr(self.config, "vocab_size", None)
            self.config.draft_vocab_size = base_vocab_size
        target_layer_num = vllm_config.model_config.get_num_layers(
            vllm_config.parallel_config
        )

        # Store target layer count in draft config for
        # proper layer_types indexing in draft models
        self.config.target_layer_count = target_layer_num
        self.model = LlamaModel(
            vllm_config=vllm_config, prefix="model", start_layer_id=target_layer_num
        )

        logit_scale = getattr(self.config, "logit_scale", 1.0)
        self.lm_head = ParallelLMHead(
            self.config.draft_vocab_size,
            self.config.hidden_size,
            quant_config=get_draft_quant_config(vllm_config),
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        self.logits_processor = LogitsProcessor(
            self.config.draft_vocab_size, scale=logit_scale
        )
        # Per-draft-token offsets: target_id = draft_id + offset (see
        # compute_logits). All zeros (identity) unless a `d2t` tensor is
        # loaded.
        self.draft_id_to_target_id = nn.Parameter(
            torch.zeros(self.config.draft_vocab_size, dtype=torch.long),
            requires_grad=False,
        )

        self.use_parallel_drafting = vllm_config.speculative_config.parallel_drafting

        if self.use_parallel_drafting:
            # NOTE(review): its 3x width in the aux case suggests mask_hidden
            # stands in for the (pre-fc) target hidden states, so its width
            # must equal the fc input width: 3 * target_hidden_size. This
            # falls back to hidden_size like LlamaModel does. Using
            # hidden_size unconditionally broke loading when
            # target_hidden_size != hidden_size.
            if self.model.use_aux_hidden_state:
                aux_hidden_size = getattr(self.config, "target_hidden_size", None)
                if aux_hidden_size is None:
                    aux_hidden_size = self.config.hidden_size
                mask_width = 3 * aux_hidden_size
            else:
                mask_width = self.config.hidden_size
            self.register_buffer(
                "mask_hidden",
                torch.zeros(1, mask_width),
                persistent=False,
            )

    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: NestedTensors | None = None,
        is_multimodal: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Embed ``input_ids``; the multimodal arguments are accepted but ignored."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Delegate to the draft backbone; returns ``(normed, prenorm)`` states."""
        return self.model(input_ids, positions, hidden_states, inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Compute logits over the full target vocabulary.

        Draft-vocabulary logits are scattered to their target ids. Target ids
        that have no draft token get ``-inf``.

        Returns:
            A ``(num_tokens, vocab_size)`` tensor, or ``None`` when the logits
            processor produced no logits on this rank.
        """
        logits = self.logits_processor(self.lm_head, hidden_states)
        if logits is None:
            # LogitsProcessor may return None, e.g. on ranks that do not
            # receive the gathered logits. Propagate it instead of crashing
            # on `logits.shape`.
            return None
        # NOTE(review): draft_id_to_target_id is always a Parameter after
        # __init__, so this branch is only reachable if it is reassigned
        # elsewhere.
        if self.draft_id_to_target_id is None:
            assert logits.shape[1] == self.config.vocab_size, (
                "Expected logits to have shape "
                f"(*, {self.config.vocab_size}), but got {logits.shape}"
            )
            return logits

        base = torch.arange(self.config.draft_vocab_size, device=logits.device)
        targets = base + self.draft_id_to_target_id
        logits_new = logits.new_full(
            (
                logits.shape[0],
                self.config.vocab_size,
            ),
            float("-inf"),
        )
        logits_new[:, targets] = logits
        return logits_new

    def combine_hidden_states(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project concatenated auxiliary target hidden states to ``hidden_size``.

        This is a no-op when the aux-hidden-state path is disabled.
        """
        if not self.model.use_aux_hidden_state:
            return hidden_states
        # combine multiple auxiliary hidden states returned by eagle3
        if self.model.norm_before_fc:
            hidden_states = self.model.input_norm(hidden_states)
        return self.model.fc(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load a draft checkpoint.

        The checkpoint names are adapted as follows:

        * ``t2d`` tensors are dropped.
        * ``d2t`` is renamed to ``draft_id_to_target_id``.
        * ``mask_hidden`` is copied straight into its buffer.
        * Everything except ``lm_head`` is prefixed with ``model.``.

        Raises:
            ValueError: If parallel drafting is enabled but the checkpoint has
                no ``mask_hidden``.
        """
        model_weights = {}
        includes_draft_id_mapping = False
        includes_embed_tokens = False
        includes_mask_hidden = False
        for name, loaded_weight in weights:
            if "t2d" in name:
                continue
            if "d2t" in name:
                name = name.replace("d2t", "draft_id_to_target_id")
                includes_draft_id_mapping = True
            elif "mask_hidden" in name:
                # Load mask_hidden directly into buffer
                if not self.use_parallel_drafting:
                    logger.warning(
                        "mask_hidden found in weights but "
                        "model is not configured for parallel drafting. "
                        "Skipping loading mask_hidden."
                    )
                    continue
                self.mask_hidden.copy_(loaded_weight.view(1, -1))
                includes_mask_hidden = True
                continue
            elif "lm_head" not in name:
                name = "model." + name
            if "embed_tokens" in name:
                includes_embed_tokens = True
            model_weights[name] = loaded_weight
            process_eagle_weight(self, name)

        if not includes_mask_hidden and self.use_parallel_drafting:
            raise ValueError(
                "mask_hidden not found in weights but "
                "model is configured for parallel drafting. "
                "Please provide mask_hidden in the weights."
            )

        # Skip parameters the checkpoint is not expected to provide, so
        # AutoWeightsLoader does not try to load them.
        skip_substrs = ["mask_hidden"]
        if not includes_draft_id_mapping:
            skip_substrs.append("draft_id_to_target_id")
        if not includes_embed_tokens:
            skip_substrs.append("embed_tokens")
        if not self.model.use_aux_hidden_state:
            skip_substrs.append("fc.")
        if not self.model.norm_before_fc:
            skip_substrs.append("input_norm.")
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=None,
            skip_substrs=skip_substrs,
        )
        loader.load_weights(model_weights.items())