llama_eagle3.py 12.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from collections.abc import Iterable
5
6
7
8
9

import torch
import torch.nn as nn
from transformers import LlamaConfig

10
from vllm.compilation.decorators import support_torch_compile
11
from vllm.config import VllmConfig, get_current_vllm_config
12
13
14
15
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import QKVParallelLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
16
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
17
from vllm.model_executor.layers.vocab_parallel_embedding import (
18
19
20
    ParallelLMHead,
    VocabParallelEmbedding,
)
21
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
22
from vllm.model_executor.models.llama import LlamaDecoderLayer, LlamaForCausalLM
23
from vllm.multimodal import MULTIMODAL_REGISTRY
24
from vllm.multimodal.inputs import NestedTensors
25

26
from .utils import AutoWeightsLoader, maybe_prefix, process_eagle_weight
27
28
29
30
31

logger = init_logger(__name__)


class LlamaDecoderLayer(LlamaDecoderLayer):
    """EAGLE3 draft decoder layer layered on the base Llama decoder layer.

    Layer 0 attends over ``cat([norm(embeds), norm(hidden_states)])``, so its
    QKV projection is rebuilt with twice the hidden width; every later layer
    keeps the plain hidden width. The rebuilt projection uses the drafter's
    quantization config (see ``get_quant_config``).
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        config: LlamaConfig | None = None,
        layer_idx: int = 0,
    ) -> None:
        super().__init__(vllm_config, prefix=prefix, config=config)

        config = config or vllm_config.model_config.hf_config
        quant_config = self.get_quant_config(vllm_config)

        # Only the first layer receives embeds and hidden states side by side,
        # so only its attention input is twice as wide.
        width_factor = 2 if layer_idx == 0 else 1

        # Swap out the parent's qkv projection for one with the right input width.
        attn = self.self_attn
        attn.qkv_proj = QKVParallelLinear(
            width_factor * self.hidden_size,
            attn.head_dim,
            attn.total_num_heads,
            attn.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "qkv_proj"),
        )

        self.hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.layer_idx = layer_idx

        # Configs may opt into normalizing hidden states before they are
        # captured as the residual; the default captures the raw input.
        norm_first = getattr(config, "norm_before_residual", False)
        self._residual_norm = (
            self._norm_before_residual if norm_first else self._norm_after_residual
        )

    def get_quant_config(self, vllm_config: VllmConfig) -> QuantizationConfig | None:
        """Return the draft model's quantization config (not the target's).

        Returns None when no draft model config is present.
        """
        draft_model_config = vllm_config.speculative_config.draft_model_config
        if not draft_model_config:
            return None
        return VllmConfig.get_quantization_config(
            draft_model_config, vllm_config.load_config
        )

    def _norm_before_residual(
        self, hidden_states: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Residual is the *normalized* hidden states.
        normed = self.hidden_norm(hidden_states)
        return normed, normed

    def _norm_after_residual(
        self, hidden_states: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Residual is the raw hidden states; only the output is normalized.
        return self.hidden_norm(hidden_states), hidden_states

    def forward(
        self,
        positions: torch.Tensor,
        embeds: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run one draft decoder layer; returns ``(mlp_output, residual)``.

        For layer 0 the incoming ``residual`` argument is not used: a fresh
        residual is derived from ``hidden_states`` via ``_residual_norm``.
        """
        if self.layer_idx != 0:
            # Later layers: residual-aware RMSNorm over hidden states only.
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        else:
            # First layer: normalize both streams, then concatenate them.
            normed_embeds = self.input_layernorm(embeds)
            hidden_states, residual = self._residual_norm(hidden_states=hidden_states)
            hidden_states = torch.cat([normed_embeds, hidden_states], dim=-1)

        attn_out = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )
        attn_out, residual = self.post_attention_layernorm(attn_out, residual)

        mlp_out = self.mlp(attn_out)
        return mlp_out, residual


122
123
124
125
126
127
128
129
130
@support_torch_compile(
    # torch.compile is disabled for multimodal EAGLE3 models due to constraint
    # violations with dynamic shapes during tensor concatenation operations.
    # See: https://github.com/vllm-project/vllm/pull/22872/files#r2362028132
    # Non-multimodal EAGLE3 models can still use torch.compile safely.
    enable_if=lambda vllm_config: not MULTIMODAL_REGISTRY.supports_multimodal_inputs(
        vllm_config.model_config
    ),
)
class LlamaModel(nn.Module):
    """Transformer stack of the EAGLE3 draft model.

    Holds the token embedding, the draft decoder layers, an ``fc`` projection
    from a 3x-wide input down to ``hidden_size``, and the final RMSNorm.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        start_layer_id: int = 0,
        prefix: str = "",
    ) -> None:
        super().__init__()
        cfg = vllm_config.speculative_config.draft_model_config.hf_config
        self.config = cfg
        self.vocab_size = cfg.vocab_size

        current_vllm_config = get_current_vllm_config()

        self.embed_tokens = VocabParallelEmbedding(
            cfg.vocab_size,
            cfg.hidden_size,
            prefix=maybe_prefix(prefix, "embed_tokens"),
        )

        # Layer prefixes are offset by start_layer_id; layer_idx stays local
        # (0-based) so layer 0 gets the widened QKV projection.
        decoder_layers = []
        for local_idx in range(cfg.num_hidden_layers):
            global_idx = local_idx + start_layer_id
            decoder_layers.append(
                LlamaDecoderLayer(
                    current_vllm_config,
                    prefix=maybe_prefix(prefix, f"layers.{global_idx}"),
                    config=cfg,
                    layer_idx=local_idx,
                )
            )
        self.layers = nn.ModuleList(decoder_layers)

        # fc input is 3x a per-state width: target_hidden_size when the draft
        # config declares one, otherwise the draft's own hidden_size.
        if hasattr(cfg, "target_hidden_size"):
            per_state_width = cfg.target_hidden_size
        else:
            per_state_width = cfg.hidden_size
        self.fc = torch.nn.Linear(per_state_width * 3, cfg.hidden_size, bias=False)

        self.norm = RMSNorm(
            cfg.hidden_size,
            eps=cfg.rms_norm_eps,
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        input_embeds: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the draft stack.

        Returns ``(normed_hidden_states, prenorm_hidden_states)`` as produced
        by the final residual-aware RMSNorm.
        """
        if input_embeds is None:
            input_embeds = self.embed_input_ids(input_ids)
        # Embeds and (already fc-projected) hidden states must share a width.
        assert hidden_states.shape[-1] == input_embeds.shape[-1]

        residual = None
        for decoder_layer in self.layers:
            hidden_states, residual = decoder_layer(
                positions=positions,
                embeds=input_embeds,
                hidden_states=hidden_states,
                residual=residual,
            )

        normed, prenorm = self.norm(hidden_states, residual)
        return normed, prenorm

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module; return loaded param names.

        Separate q/k/v and gate/up checkpoint tensors are routed into the
        fused ``qkv_proj`` / ``gate_up_proj`` params with their shard ids.
        Raises KeyError for a name with no matching parameter.
        """
        # (fused param suffix, checkpoint suffix, shard id)
        stacked_params_mapping = (
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        )
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        for ckpt_name, loaded_weight in weights:
            # Checkpoints may name the first decoder layer "midlayer.".
            name = ckpt_name.replace("midlayer.", "layers.0.")
            stacked = next(
                (entry for entry in stacked_params_mapping if entry[1] in name),
                None,
            )
            if stacked is None:
                param = params_dict[name]
                loader_fn = getattr(param, "weight_loader", default_weight_loader)
                loader_fn(param, loaded_weight)
            else:
                fused_suffix, ckpt_suffix, shard_id = stacked
                name = name.replace(ckpt_suffix, fused_suffix)
                param = params_dict[name]
                param.weight_loader(param, loaded_weight, shard_id)
            loaded_params.add(name)

        return loaded_params


class Eagle3LlamaForCausalLM(LlamaForCausalLM):
    """EAGLE3 drafter: draft ``LlamaModel`` plus an LM head over the draft vocab.

    Draft logits cover ``draft_vocab_size`` ids. When the checkpoint ships a
    ``d2t`` offset table, ``compute_logits`` scatters them into the target
    vocabulary (``vocab_size`` columns, ``-inf`` elsewhere).
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # Calls nn.Module.__init__ directly (bypassing LlamaForCausalLM's
        # constructor) and builds the draft-specific submodules below.
        nn.Module.__init__(self)
        self.config = vllm_config.speculative_config.draft_model_config.hf_config
        # Ensure draft_vocab_size is set
        # default to the base vocab size when absent
        if getattr(self.config, "draft_vocab_size", None) is None:
            base_vocab_size = getattr(self.config, "vocab_size", None)
            self.config.draft_vocab_size = base_vocab_size
        target_layer_num = vllm_config.model_config.get_num_layers(
            vllm_config.parallel_config
        )

        # Store target layer count in draft config for
        # proper layer_types indexing in draft models
        self.config.target_layer_count = target_layer_num
        self.model = LlamaModel(
            vllm_config=vllm_config, prefix="model", start_layer_id=target_layer_num
        )

        logit_scale = getattr(self.config, "logit_scale", 1.0)
        self.lm_head = ParallelLMHead(
            self.config.draft_vocab_size,
            self.config.hidden_size,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        self.logits_processor = LogitsProcessor(
            self.config.draft_vocab_size, scale=logit_scale
        )
        # Per-draft-id offsets: draft id i maps to target id i + offset[i].
        # Zero-filled unless a ``d2t`` tensor is loaded; load_weights sets it
        # to None when no table is present and the vocab sizes match.
        self.draft_id_to_target_id = nn.Parameter(
            torch.zeros(self.config.draft_vocab_size, dtype=torch.long),
            requires_grad=False,
        )

    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: NestedTensors | None = None,
        is_multimodal: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Embed token ids; the multimodal arguments are accepted but unused."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Delegate to the draft ``LlamaModel``."""
        return self.model(input_ids, positions, hidden_states, inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Compute draft logits, expressed over the target vocabulary.

        Returns None when the logits processor returns None. When
        ``draft_id_to_target_id`` is None, the draft logits already span the
        target vocabulary and are returned unchanged. Otherwise draft column
        ``i`` is written to target column ``i + draft_id_to_target_id[i]`` of
        a ``(num_tokens, vocab_size)`` tensor whose other entries are ``-inf``.
        """
        logits = self.logits_processor(self.lm_head, hidden_states)
        if logits is None:
            # The processor's result is Optional; nothing to remap.
            return None

        if self.draft_id_to_target_id is None:
            assert logits.shape[1] == self.config.vocab_size, (
                "Expected logits to have shape "
                f"(*, {self.config.vocab_size}), but got {logits.shape}"
            )
            return logits

        base = torch.arange(self.config.draft_vocab_size, device=logits.device)
        targets = base + self.draft_id_to_target_id
        logits_new = logits.new_full(
            (
                logits.shape[0],
                self.config.vocab_size,
            ),
            float("-inf"),
        )
        logits_new[:, targets] = logits
        return logits_new

    def combine_hidden_states(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        # combine multiple auxiliary hidden states returned by eagle3
        return self.model.fc(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load an EAGLE3 draft checkpoint.

        Names are rewritten before delegating to ``AutoWeightsLoader``:
        ``t2d`` tensors are dropped, ``d2t`` becomes ``draft_id_to_target_id``,
        and everything other than ``lm_head`` gets a ``model.`` prefix.
        ``draft_id_to_target_id`` and ``embed_tokens`` are skipped when the
        checkpoint does not contain them.

        NOTE(review): returns None (not the set of loaded names), matching
        the previous behavior; confirm callers before changing that.
        """
        model_weights = {}
        includes_draft_id_mapping = False
        includes_embed_tokens = False
        for name, loaded_weight in weights:
            if "t2d" in name:
                continue
            if "d2t" in name:
                name = name.replace("d2t", "draft_id_to_target_id")
                includes_draft_id_mapping = True
            elif "lm_head" not in name:
                name = "model." + name
            if "embed_tokens" in name:
                includes_embed_tokens = True
            model_weights[name] = loaded_weight
            process_eagle_weight(self, name)

        skip_substrs = []
        if not includes_draft_id_mapping:
            skip_substrs.append("draft_id_to_target_id")
            if self.config.draft_vocab_size == self.config.vocab_size:
                # With no d2t table the offsets stay zero, so with equal vocab
                # sizes the remap in compute_logits would be an identity
                # scatter into a freshly allocated -inf tensor on every call.
                # Clearing the parameter makes compute_logits return the
                # logits directly (same values, no extra allocation).
                self.draft_id_to_target_id = None
        if not includes_embed_tokens:
            skip_substrs.append("embed_tokens")

        loader = AutoWeightsLoader(
            self,
            skip_prefixes=None,
            skip_substrs=skip_substrs,
        )
        loader.load_weights(model_weights.items())