# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""EAGLE-3 speculative-decoding draft model for Llama-architecture drafters."""

from collections.abc import Iterable

import torch
import torch.nn as nn
from transformers import LlamaConfig

from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import QKVParallelLinear, ReplicatedLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
from vllm.model_executor.models.llama import LlamaDecoderLayer, LlamaForCausalLM
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import NestedTensors

from .utils import (
    AutoWeightsLoader,
    get_draft_quant_config,
    maybe_prefix,
    process_eagle_weight,
)

logger = init_logger(__name__)


class LlamaDecoderLayer(LlamaDecoderLayer):
    """Llama decoder layer adapted for the EAGLE-3 draft model.

    This class intentionally shadows the imported ``LlamaDecoderLayer``. The
    base class is resolved when this ``class`` statement executes, so
    subclassing the same name works. Layer 0 consumes the token embeddings
    concatenated with the incoming hidden states, so its QKV projection takes
    ``2 * hidden_size`` inputs. Later layers consume only the hidden states.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        config: LlamaConfig | None = None,
        layer_idx: int = 0,
    ) -> None:
        """Build the layer.

        Args:
            vllm_config: Global vLLM config. It is forwarded to the base layer
                and used to look up the drafter's quantization config.
            prefix: Parameter-name prefix for this layer.
            config: HF config of the draft model. Falls back to
                ``vllm_config.model_config.hf_config`` when ``None``.
            layer_idx: Index of this layer within the draft model. Only
                layer 0 gets the widened QKV projection.
        """
        super().__init__(vllm_config, prefix=prefix, config=config)

        config = config or vllm_config.model_config.hf_config
        quant_config = self.get_quant_config(vllm_config)

        # First layer uses 2*hidden_size (embeds + hidden_states concatenated)
        # Subsequent layers use hidden_size (only hidden_states, no embeds)
        qkv_input_size = 2 * self.hidden_size if layer_idx == 0 else self.hidden_size

        # Override the base layer's QKV projection so its input width matches
        # qkv_input_size. It is built with the drafter's quant config.
        self.self_attn.qkv_proj = QKVParallelLinear(
            qkv_input_size,
            self.self_attn.head_dim,
            self.self_attn.total_num_heads,
            self.self_attn.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "qkv_proj"),
        )

        # Norm for the incoming hidden states. forward() only uses it on
        # layer 0 (through self._residual_norm).
        self.hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.layer_idx = layer_idx

        # Choose once, at construction, whether layer 0's residual is taken
        # before or after hidden_norm, so forward() does not re-read config.
        if getattr(config, "norm_before_residual", False):
            self._residual_norm = self._norm_before_residual
        else:
            self._residual_norm = self._norm_after_residual

    def get_quant_config(self, vllm_config: VllmConfig) -> QuantizationConfig | None:
        """Use drafter's quantization config instead of verifier's."""
        return get_draft_quant_config(vllm_config)

    def _norm_before_residual(
        self, hidden_states: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Normalize first, then use the *normalized* tensor as the residual."""
        hidden_states = self.hidden_norm(hidden_states)
        residual = hidden_states
        return hidden_states, residual

    def _norm_after_residual(
        self, hidden_states: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Keep the raw input as the residual and normalize only the output."""
        residual = hidden_states
        hidden_states = self.hidden_norm(hidden_states)
        return hidden_states, residual

    def forward(
        self,
        positions: torch.Tensor,
        embeds: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run one draft decoder layer.

        Args:
            positions: Positions forwarded to self-attention.
            embeds: Token embeddings. Only layer 0 reads them.
            hidden_states: Hidden states entering this layer.
            residual: Residual stream from the previous layer. Layer 0
                ignores it and recomputes it from ``hidden_states``.

        Returns:
            ``(hidden_states, residual)`` after the MLP, to feed the next layer.
        """
        if self.layer_idx == 0:
            # First layer: normalize embeds and hidden states separately, then
            # concatenate them to match the 2 * hidden_size QKV input.
            embeds = self.input_layernorm(embeds)
            hidden_states, residual = self._residual_norm(hidden_states=hidden_states)
            hidden_states = torch.cat([embeds, hidden_states], dim=-1)
        else:
            # Subsequent layers: the two-argument norm call returns the normed
            # states together with the updated residual.
            hidden_states, residual = self.input_layernorm(hidden_states, residual)

        # Self Attention
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )

        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)

        # Fully Connected
        hidden_states = self.mlp(hidden_states)

        return hidden_states, residual


@support_torch_compile(
    # torch.compile is disabled for multimodal EAGLE3 models due to constraint
    # violations with dynamic shapes during tensor concatenation operations.
    # See: https://github.com/vllm-project/vllm/pull/22872/files#r2362028132
    # Non-multimodal EAGLE3 models can still use torch.compile safely.
    enable_if=lambda vllm_config: not MULTIMODAL_REGISTRY.supports_multimodal_inputs(
        vllm_config.model_config
    ),
)
class LlamaModel(nn.Module):
    """Transformer body of the EAGLE-3 Llama draft model.

    Holds the token embedding, the draft decoder layers, the ``fc``
    projection that fuses auxiliary hidden states, and the final RMSNorm.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        start_layer_id: int = 0,
        prefix: str = "",
    ) -> None:
        """Build the draft model.

        Args:
            vllm_config: Global vLLM config. The draft model's HF config is
                read from ``speculative_config.draft_model_config``.
            start_layer_id: Offset added to each layer index when building
                parameter-name prefixes (``layers.{idx + start_layer_id}``).
            prefix: Parameter-name prefix for submodules.
        """
        super().__init__()
        self.config = vllm_config.speculative_config.draft_model_config.hf_config
        self.vocab_size = self.config.vocab_size

        # Get drafter's quantization config
        self.quant_config = get_draft_quant_config(vllm_config)

        # NOTE(review): the decoder layers are built from
        # get_current_vllm_config() rather than the vllm_config argument;
        # presumably intentional — confirm.
        current_vllm_config = get_current_vllm_config()

        self.embed_tokens = VocabParallelEmbedding(
            self.config.vocab_size,
            self.config.hidden_size,
            prefix=maybe_prefix(prefix, "embed_tokens"),
        )

        self.layers = nn.ModuleList(
            [
                LlamaDecoderLayer(
                    current_vllm_config,
                    prefix=maybe_prefix(prefix, f"layers.{layer_idx + start_layer_id}"),
                    config=self.config,
                    layer_idx=layer_idx,
                )
                for layer_idx in range(self.config.num_hidden_layers)
            ]
        )

        # fc maps 3 * target_hidden_size inputs down to hidden_size. Fall back
        # to the draft hidden size when target_hidden_size is absent *or*
        # explicitly None; a None value would otherwise fail on `None * 3`.
        target_hidden_size = (
            getattr(self.config, "target_hidden_size", None) or self.config.hidden_size
        )
        self.fc = ReplicatedLinear(
            input_size=target_hidden_size * 3,
            output_size=self.config.hidden_size,
            bias=False,
            params_dtype=vllm_config.model_config.dtype,
            quant_config=self.quant_config,
            prefix=maybe_prefix(prefix, "fc"),
            return_bias=False,
        )

        self.norm = RMSNorm(
            self.config.hidden_size,
            eps=self.config.rms_norm_eps,
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        input_embeds: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run all draft decoder layers and the final norm.

        Args:
            input_ids: Token ids. They are embedded only when ``input_embeds``
                is None.
            positions: Positions forwarded to every layer.
            hidden_states: Incoming hidden states. Their last dimension must
                equal the embeddings' last dimension (asserted below).
            input_embeds: Optional precomputed token embeddings.

        Returns:
            ``(hidden_states, hidden_prenorm)``, the two values returned by
            the final two-argument ``self.norm`` call. The second value is
            presumably the pre-norm residual sum — confirm against RMSNorm.
        """
        if input_embeds is None:
            input_embeds = self.embed_input_ids(input_ids)
        assert hidden_states.shape[-1] == input_embeds.shape[-1]

        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(
                positions=positions,
                embeds=input_embeds,
                hidden_states=hidden_states,
                residual=residual,
            )
        hidden_states, hidden_prenorm = self.norm(hidden_states, residual)
        return hidden_states, hidden_prenorm

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Name handling:
          * ``midlayer.`` is renamed to ``layers.0.``.
          * KV-cache quantization scales are routed through the quant config
            (0-d tensors as-is, otherwise element ``[0]``).
          * Other ``scale`` names are remapped via
            ``maybe_remap_kv_scale_name`` and dropped when it returns None.
          * Separate q/k/v and gate/up shards are loaded into the fused
            ``qkv_proj`` / ``gate_up_proj`` parameters.

        Returns:
            The set of parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "midlayer." in name:
                name = name.replace("midlayer.", "layers.0.")
            # Handle kv cache quantization scales
            if self.quant_config is not None and (
                scale_name := self.quant_config.get_cache_scale(name)
            ):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                loaded_weight = (
                    loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
                )
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue
            # Remapping the name FP8 kv-scale
            if "scale" in name:
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a stacked shard: load the tensor directly.
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


class Eagle3LlamaForCausalLM(LlamaForCausalLM):
    """EAGLE-3 draft model with a Llama body and a possibly reduced vocabulary.

    Logits are computed over ``draft_vocab_size`` tokens and then scattered
    into a full ``vocab_size`` tensor using ``draft_id_to_target_id``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # Deliberately skip LlamaForCausalLM.__init__; only plain nn.Module
        # setup is needed before the draft-specific submodules are built.
        nn.Module.__init__(self)
        self.config = vllm_config.speculative_config.draft_model_config.hf_config

        # Ensure draft_vocab_size is set
        # default to the base vocab size when absent
        if getattr(self.config, "draft_vocab_size", None) is None:
            base_vocab_size = getattr(self.config, "vocab_size", None)
            self.config.draft_vocab_size = base_vocab_size

        target_layer_num = vllm_config.model_config.get_num_layers(
            vllm_config.parallel_config
        )

        # Store target layer count in draft config for
        # proper layer_types indexing in draft models
        self.config.target_layer_count = target_layer_num

        # Draft layer prefixes start after the target's layers
        # (layers.{target_layer_num + i}).
        self.model = LlamaModel(
            vllm_config=vllm_config, prefix="model", start_layer_id=target_layer_num
        )

        logit_scale = getattr(self.config, "logit_scale", 1.0)
        self.lm_head = ParallelLMHead(
            self.config.draft_vocab_size,
            self.config.hidden_size,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        self.logits_processor = LogitsProcessor(
            self.config.draft_vocab_size, scale=logit_scale
        )
        # Per-draft-token *offset*: target_id = draft_id + draft_id_to_target_id[draft_id]
        # (see compute_logits). All zeros means the identity mapping.
        self.draft_id_to_target_id = nn.Parameter(
            torch.zeros(self.config.draft_vocab_size, dtype=torch.long),
            requires_grad=False,
        )

    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: NestedTensors | None = None,
        is_multimodal: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Embed ``input_ids``. The multimodal arguments are accepted but unused."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Delegate to the draft ``LlamaModel``; returns its two outputs."""
        return self.model(input_ids, positions, hidden_states, inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Compute draft logits and scatter them into the full target vocab.

        Returns:
            A ``(num_tokens, vocab_size)`` tensor in which target tokens with
            no draft counterpart are ``-inf``. Returns ``None`` when the
            logits processor returns ``None``.
        """
        logits = self.logits_processor(self.lm_head, hidden_states)
        if logits is None:
            # The logits processor may yield None (presumably on ranks that
            # do not receive gathered logits). Propagate it, as the return
            # annotation allows, instead of crashing on logits.shape.
            return None

        if self.draft_id_to_target_id is None:
            assert logits.shape[1] == self.config.vocab_size, (
                "Expected logits to have shape "
                f"(*, {self.config.vocab_size}), but got {logits.shape}"
            )
            return logits

        # Map each draft id to its target id (offset-encoded mapping).
        base = torch.arange(self.config.draft_vocab_size, device=logits.device)
        targets = base + self.draft_id_to_target_id
        # Tokens outside the draft vocabulary get -inf so they are never chosen.
        logits_new = logits.new_full(
            (
                logits.shape[0],
                self.config.vocab_size,
            ),
            float("-inf"),
        )
        logits_new[:, targets] = logits
        return logits_new

    def combine_hidden_states(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        # combine multiple auxiliary hidden states returned by eagle3
        return self.model.fc(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Rename EAGLE-3 checkpoint tensors and load them.

        * ``t2d`` tensors are skipped.
        * ``d2t`` is renamed to ``draft_id_to_target_id``.
        * Every other name except ``lm_head`` gets a ``model.`` prefix.
        * ``draft_id_to_target_id`` / ``embed_tokens`` are excluded from
          loading when the checkpoint does not contain them.
        """
        model_weights = {}
        includes_draft_id_mapping = False
        includes_embed_tokens = False
        for name, loaded_weight in weights:
            if "t2d" in name:
                continue
            if "d2t" in name:
                name = name.replace("d2t", "draft_id_to_target_id")
                includes_draft_id_mapping = True
            elif "lm_head" not in name:
                name = "model." + name
            if "embed_tokens" in name:
                includes_embed_tokens = True
            model_weights[name] = loaded_weight
            # Project helper invoked per renamed weight; its exact effect is
            # defined in .utils — NOTE(review): document there.
            process_eagle_weight(self, name)

        skip_substrs = []
        if not includes_draft_id_mapping:
            skip_substrs.append("draft_id_to_target_id")
        if not includes_embed_tokens:
            skip_substrs.append("embed_tokens")
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=None,
            skip_substrs=skip_substrs,
        )
        loader.load_weights(model_weights.items())