# File: llama_eagle3.py (10.5 KB)
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
from collections.abc import Iterable
from typing import Optional
6
7
8
9
10

import torch
import torch.nn as nn
from transformers import LlamaConfig

11
from vllm.compilation.decorators import support_torch_compile
12
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import QKVParallelLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import (LlamaDecoderLayer,
                                              LlamaForCausalLM)

from .utils import AutoWeightsLoader, maybe_prefix

# Module-level logger for this file. It is currently not referenced by any of
# the definitions below.
logger = init_logger(__name__)


class LlamaDecoderLayer(LlamaDecoderLayer):
    """Single decoder layer used by the EAGLE-3 Llama draft model.

    Subclasses (and intentionally shadows the name of) vLLM's
    ``LlamaDecoderLayer``. Attention is fed the concatenation of the
    normalized token embeddings and the normalized incoming hidden states,
    so the QKV projection is rebuilt to accept ``2 * hidden_size`` input
    features.
    """

    def __init__(
        self,
        config: LlamaConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(config,
                         cache_config=cache_config,
                         quant_config=quant_config,
                         prefix=prefix)

        attn = self.self_attn
        # Replace the parent's QKV projection: its input is the
        # concatenation [embeds, hidden_states] built in ``forward``,
        # hence twice the hidden width.
        attn.qkv_proj = QKVParallelLinear(
            2 * self.hidden_size,
            attn.head_dim,
            attn.total_num_heads,
            attn.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "qkv_proj"),
        )

        # Norm applied to the incoming hidden states before concatenation.
        self.hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Decide once, at construction time, whether the residual branch
        # carries the normalized or the raw incoming hidden states.
        norm_first = getattr(config, "norm_before_residual", False)
        self._residual_norm = (self._norm_before_residual
                               if norm_first else self._norm_after_residual)

    def _norm_before_residual(
            self,
            hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Normalize first; the normalized tensor doubles as the residual."""
        normed = self.hidden_norm(hidden_states)
        return normed, normed

    def _norm_after_residual(
            self,
            hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Keep the raw input as the residual; return its normalization."""
        raw = hidden_states
        return self.hidden_norm(raw), raw

    def forward(
        self,
        positions: torch.Tensor,
        embeds: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the draft layer on token embeddings plus incoming states.

        Args:
            positions: Token positions forwarded to self-attention.
            embeds: Token embeddings; normalized with ``input_layernorm``.
            hidden_states: Incoming hidden states; normalized with
                ``hidden_norm`` and concatenated after ``embeds`` on the
                last dimension.
            residual: Ignored; the residual is recomputed from
                ``hidden_states`` by ``self._residual_norm``.

        Returns:
            ``(mlp_output, residual)`` where ``residual`` is the second value
            returned by ``post_attention_layernorm``.
        """
        normed_embeds = self.input_layernorm(embeds)
        normed_hidden, residual = self._residual_norm(
            hidden_states=hidden_states)

        # Attention sees [embeds | hidden] along the feature dimension.
        attn_input = torch.cat([normed_embeds, normed_hidden], dim=-1)
        attn_output = self.self_attn(
            positions=positions,
            hidden_states=attn_input,
        )

        mlp_input, residual = self.post_attention_layernorm(
            attn_output, residual)
        return self.mlp(mlp_input), residual


@support_torch_compile
class LlamaModel(nn.Module):
    """Single-layer EAGLE-3 draft backbone.

    Embeds the input tokens, runs them together with the incoming hidden
    states through one ``LlamaDecoderLayer`` and applies a final RMSNorm.
    ``fc`` maps three concatenated hidden-state vectors down to
    ``hidden_size``; it is not used by ``forward`` itself.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        start_layer_id: int = 0,
        prefix: str = "",
    ) -> None:
        super().__init__()
        draft_hf_config = (
            vllm_config.speculative_config.draft_model_config.hf_config)
        self.config = draft_hf_config
        self.vocab_size = draft_hf_config.vocab_size

        current_vllm_config = get_current_vllm_config()

        self.embed_tokens = VocabParallelEmbedding(
            draft_hf_config.vocab_size,
            draft_hf_config.hidden_size,
            prefix=maybe_prefix(prefix, "embed_tokens"),
        )

        # Exactly one decoder layer; ``start_layer_id`` only affects the
        # layer's prefix (``layers.{start_layer_id}``), while its index in
        # the ModuleList is always 0.
        draft_layer = LlamaDecoderLayer(
            config=draft_hf_config,
            cache_config=current_vllm_config.cache_config,
            prefix=maybe_prefix(prefix, f"layers.{start_layer_id}"),
        )
        self.layers = nn.ModuleList([draft_layer])

        # Input width is 3 * target_hidden_size when the config defines it,
        # otherwise 3 * the draft hidden size.
        fc_hidden = getattr(draft_hf_config, "target_hidden_size",
                            draft_hf_config.hidden_size)
        self.fc = torch.nn.Linear(fc_hidden * 3,
                                  draft_hf_config.hidden_size,
                                  bias=False)
        self.norm = RMSNorm(draft_hf_config.hidden_size,
                            eps=draft_hf_config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Embed ``input_ids`` and run the draft layer plus final norm.

        Returns the pair produced by ``self.norm(layer_out, residual)``:
        the normalized hidden states and the second value of the norm's
        residual path (presumably the pre-norm hidden states).
        """
        input_embeds = self.embed_tokens(input_ids)
        # The decoder layer concatenates these two tensors into a
        # 2 * hidden_size QKV input, so their last dims must agree.
        assert hidden_states.shape[-1] == input_embeds.shape[-1]

        layer_out, layer_residual = self.layers[0](
            positions,
            input_embeds,
            hidden_states,
            None,
        )

        normed, prenorm = self.norm(layer_out, layer_residual)
        return normed, prenorm

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        ``midlayer.`` names are mapped onto ``layers.0.``; separate q/k/v
        and gate/up projections are routed into the fused ``qkv_proj`` /
        ``gate_up_proj`` parameters with the matching shard id.

        Returns:
            The set of (renamed) parameter names that were loaded.
        """
        # (fused param name, checkpoint shard name, shard id); first match wins.
        stacked_params_mapping = (
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        )
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            # No-op when the checkpoint name has no "midlayer." segment.
            name = name.replace('midlayer.', 'layers.0.')

            stacked = next(
                (entry for entry in stacked_params_mapping
                 if entry[1] in name),
                None,
            )
            if stacked is None:
                param = params_dict[name]
                loader_fn = getattr(param, "weight_loader",
                                    default_weight_loader)
                loader_fn(param, loaded_weight)
            else:
                fused_name, shard_name, shard_id = stacked
                name = name.replace(shard_name, fused_name)
                param = params_dict[name]
                param.weight_loader(param, loaded_weight, shard_id)

            loaded_params.add(name)
        return loaded_params


class Eagle3LlamaForCausalLM(LlamaForCausalLM):
    """EAGLE-3 speculative-decoding draft model with a Llama backbone.

    Wraps the single-layer draft ``LlamaModel`` with an LM head over the
    draft vocabulary (``draft_vocab_size``). ``compute_logits`` scatters the
    draft-vocabulary logits into the full target vocabulary
    (``vocab_size``) using the ``draft_id_to_target_id`` offsets.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # Deliberately skip LlamaForCausalLM.__init__: only nn.Module state
        # is initialized here and every submodule is built explicitly below.
        nn.Module.__init__(self)
        self.config = (
            vllm_config.speculative_config.draft_model_config.hf_config)
        target_layer_num = vllm_config.model_config.get_num_layers(
            vllm_config.parallel_config)

        # Store target layer count in draft config for
        # proper layer_types indexing in draft models
        self.config.target_layer_count = target_layer_num
        # The draft layer is numbered after the target's layers via
        # ``start_layer_id``, presumably so its layer prefix does not collide
        # with a target layer name.
        self.model = LlamaModel(vllm_config=vllm_config,
                                prefix="model",
                                start_layer_id=target_layer_num)

        logit_scale = getattr(self.config, "logit_scale", 1.0)
        self.lm_head = ParallelLMHead(
            self.config.draft_vocab_size,
            self.config.hidden_size,
            org_num_embeddings=self.config.draft_vocab_size,
            padding_size=(DEFAULT_VOCAB_PADDING_SIZE),
            prefix=maybe_prefix(prefix, "lm_head"))
        self.logits_processor = LogitsProcessor(self.config.draft_vocab_size,
                                                scale=logit_scale)
        # Per-draft-token offsets: target_id = draft_id + offset[draft_id].
        # Zero-initialized, i.e. an identity mapping unless a "d2t" tensor
        # is loaded from the checkpoint (see load_weights).
        self.draft_id_to_target_id = nn.Parameter(
            torch.zeros(self.config.draft_vocab_size, dtype=torch.long),
            requires_grad=False,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the draft backbone.

        Raises:
            NotImplementedError: if ``inputs_embeds`` is given (multimodal
                inputs are not supported).
        """
        if inputs_embeds is not None:
            raise NotImplementedError(
                f"{type(self).__name__} does not support multimodal inputs yet."
            )
        return self.model(input_ids, positions, hidden_states)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        """Compute target-vocabulary logits from draft hidden states.

        Returns:
            ``None`` if the logits processor produced no logits. Otherwise a
            ``(logits.shape[0], vocab_size)`` tensor in which each draft
            token's logit sits at its mapped target id and every other entry
            is ``-inf``. If ``draft_id_to_target_id`` has been set to
            ``None``, the draft logits are returned unchanged and must
            already span the full ``vocab_size``.
        """
        logits = self.logits_processor(self.lm_head, hidden_states)
        if logits is None:
            # NOTE(review): this method is typed Optional and the logits
            # processor may yield no logits (e.g. on tensor-parallel ranks
            # that do not receive gathered logits); nothing to remap then.
            return None

        if self.draft_id_to_target_id is None:
            assert logits.shape[1] == self.config.vocab_size, \
                "Expected logits to have shape " \
                f"(*, {self.config.vocab_size}), but got {logits.shape}"
            return logits

        # Absolute target ids = draft ids + stored offsets.
        base = torch.arange(self.config.draft_vocab_size, device=logits.device)
        targets = base + self.draft_id_to_target_id
        logits_new = logits.new_full((
            logits.shape[0],
            self.config.vocab_size,
        ), float('-inf'))
        logits_new[:, targets] = logits
        return logits_new

    def combine_hidden_states(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project concatenated auxiliary hidden states through ``model.fc``.

        NOTE(review): assumes the last dimension equals ``fc``'s input width
        (3 * target_hidden_size, or 3 * hidden_size if the former is unset).
        """
        # combine multiple auxiliary hidden states returned by eagle3
        return self.model.fc(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Rename checkpoint tensors and load them via AutoWeightsLoader.

        Renaming rules:
          * names containing "t2d" are dropped;
          * "d2t" becomes "draft_id_to_target_id" (a top-level parameter);
          * everything else except "lm_head" is prefixed with "model.".

        ``draft_id_to_target_id`` / ``embed_tokens`` are skipped when the
        checkpoint does not provide them, leaving their initial values.

        NOTE(review): the loader's result is not returned, so callers get
        ``None``; confirm this is intentional (skipped tensors would
        otherwise appear as missing).
        """
        model_weights = {}
        includes_draft_id_mapping = False
        includes_embed_tokens = False
        for name, loaded_weight in weights:
            if "t2d" in name:
                continue
            if "d2t" in name:
                name = name.replace("d2t", "draft_id_to_target_id")
                includes_draft_id_mapping = True
            elif "lm_head" not in name:
                name = "model." + name
            if "embed_tokens" in name:
                includes_embed_tokens = True
            model_weights[name] = loaded_weight

        skip_substrs = []
        if not includes_draft_id_mapping:
            skip_substrs.append("draft_id_to_target_id")
        if not includes_embed_tokens:
            skip_substrs.append("embed_tokens")

        loader = AutoWeightsLoader(
            self,
            skip_prefixes=None,
            skip_substrs=skip_substrs,
        )
        loader.load_weights(model_weights.items())