# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Iterable
from typing import Optional
6
7
8
9
10

import torch
import torch.nn as nn
from transformers import LlamaConfig

11
12
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import QKVParallelLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import (LlamaDecoderLayer,
                                              LlamaForCausalLM)
from vllm.v1.sample.metadata import SamplingMetadata

from .utils import AutoWeightsLoader, maybe_prefix

# Module-level logger for this file, created via vLLM's init_logger.
logger = init_logger(__name__)


class LlamaDecoderLayer(LlamaDecoderLayer):
    """EAGLE-3 draft decoder layer.

    Extends the imported Llama decoder layer, which this class intentionally
    shadows within this module. The attention QKV projection is rebuilt to
    accept ``2 * hidden_size`` input features, because ``forward``
    concatenates the normalized token embeddings with the normalized incoming
    hidden states before attention.
    """

    def __init__(
        self,
        config: LlamaConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(config, quant_config=quant_config, prefix=prefix)

        # Override qkv: the attention input is [embeds, hidden_states]
        # joined on the last dim, i.e. twice the hidden size.
        self.self_attn.qkv_proj = QKVParallelLinear(
            2 * self.hidden_size,
            self.self_attn.head_dim,
            self.self_attn.total_num_heads,
            self.self_attn.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "qkv_proj"),
        )

        # Dedicated norm for the incoming hidden states; input_layernorm is
        # applied to the token embeddings instead (see forward).
        self.hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Decide once whether the residual branch carries the normalized or
        # the raw hidden states. Controlled by the optional config flag
        # ``norm_before_residual`` (treated as False when absent).
        if getattr(config, "norm_before_residual", False):
            self._residual_norm = self._norm_before_residual
        else:
            self._residual_norm = self._norm_after_residual

    def _norm_before_residual(
            self,
            hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Normalize first; the residual is the *normalized* tensor."""
        hidden_states = self.hidden_norm(hidden_states)
        residual = hidden_states
        return hidden_states, residual

    def _norm_after_residual(
            self,
            hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Keep the raw input as residual; return normalized hidden states."""
        residual = hidden_states
        hidden_states = self.hidden_norm(hidden_states)
        return hidden_states, residual

    def forward(
        self,
        positions: torch.Tensor,
        embeds: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the single EAGLE-3 draft layer.

        Args:
            positions: Token positions, forwarded to self-attention.
            embeds: Token embeddings for the current tokens.
            hidden_states: Hidden states fed into the draft layer.
            residual: Accepted for interface compatibility but ignored; the
                residual is recomputed from ``hidden_states`` here.

        Returns:
            ``(hidden_states, residual)``: the MLP output and the residual
            returned by ``post_attention_layernorm``.
        """
        embeds = self.input_layernorm(embeds)
        hidden_states, residual = self._residual_norm(
            hidden_states=hidden_states)

        # Attention consumes [embeds, hidden_states] concatenated on the last
        # dim, matching the 2 * hidden_size qkv_proj built in __init__.
        hidden_states = torch.cat([embeds, hidden_states], dim=-1)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )

        # RMSNorm called with a residual returns (output, new_residual);
        # presumably it adds the residual before normalizing (vLLM fused
        # add+norm convention) -- confirm against RMSNorm if changing.
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)

        # Fully Connected
        hidden_states = self.mlp(hidden_states)

        return hidden_states, residual


@support_torch_compile
class LlamaModel(nn.Module):
    """Single-layer EAGLE-3 draft model body.

    Holds the token embedding, one draft decoder layer, the ``fc``
    projection used to fuse auxiliary hidden states, and the final norm.
    All sizes come from the speculative draft model's HF config.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        start_layer_id: int = 0,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = (
            vllm_config.speculative_config.draft_model_config.hf_config)
        self.vocab_size = self.config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.config.vocab_size,
            self.config.hidden_size,
            prefix=maybe_prefix(prefix, "embed_tokens"),
        )

        # The draft layer's prefix uses start_layer_id rather than 0;
        # presumably so its (attention) layer name does not collide with the
        # target model's layers -- the caller passes the target layer count.
        self.layers = nn.ModuleList([
            LlamaDecoderLayer(
                config=self.config,
                prefix=maybe_prefix(prefix, f"layers.{start_layer_id}"),
            )
        ])

        # fc maps 3 concatenated auxiliary hidden states down to the draft
        # hidden size. Their width is ``target_hidden_size`` when the draft
        # config defines it, otherwise the draft's own ``hidden_size``.
        aux_hidden_size = getattr(self.config, "target_hidden_size",
                                  self.config.hidden_size)
        self.fc = torch.nn.Linear(aux_hidden_size * 3,
                                  self.config.hidden_size,
                                  bias=False)
        self.norm = RMSNorm(
            self.config.hidden_size,
            eps=self.config.rms_norm_eps,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Embed ``input_ids`` and run the draft layer plus final norm.

        Args:
            input_ids: Token ids to embed.
            positions: Token positions, forwarded to the draft layer.
            hidden_states: Incoming hidden states; their last dim must equal
                the embedding width (presumably already projected by ``fc``).

        Returns:
            ``(hidden_states, hidden_prenorm)`` as returned by ``self.norm``
            when called with the layer's residual.
        """
        input_embeds = self.embed_tokens(input_ids)
        assert hidden_states.shape[-1] == input_embeds.shape[-1]

        residual = None
        hidden_states, residual = self.layers[0](
            positions,
            input_embeds,
            hidden_states,
            residual,
        )

        hidden_states, hidden_prenorm = self.norm(hidden_states, residual)
        return hidden_states, hidden_prenorm

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load draft-model weights into this module.

        Checkpoint names with the ``midlayer.`` prefix are mapped onto
        ``layers.0.``. Separate q/k/v and gate/up projection weights are
        routed into the fused ``qkv_proj`` / ``gate_up_proj`` parameters via
        their ``weight_loader`` with the matching shard id; every other
        weight uses the parameter's ``weight_loader`` if it has one, else
        ``default_weight_loader``.

        Returns:
            The set of (renamed) parameter names that were loaded.

        Raises:
            KeyError: If a (renamed) weight name has no matching parameter.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if 'midlayer.' in name:
                name = name.replace('midlayer.', 'layers.0.')
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a shard of a fused parameter: load it directly.
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


class Eagle3LlamaForCausalLM(LlamaForCausalLM):
    """EAGLE-3 speculative-decoding draft model for Llama targets.

    Wraps the single-layer draft ``LlamaModel`` with an LM head over the
    draft vocabulary (``draft_vocab_size``). ``compute_logits`` scatters
    draft-vocabulary logits back into the full ``vocab_size`` space.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # Bypass LlamaForCausalLM.__init__: every submodule is built below
        # from the speculative draft model's config instead.
        nn.Module.__init__(self)
        self.config = (
            vllm_config.speculative_config.draft_model_config.hf_config)
        # Layer count from vllm_config.model_config (presumably the target
        # model); used to number the draft layer after the target's layers.
        target_layer_num = vllm_config.model_config.get_num_layers(
            vllm_config.parallel_config)
        self.model = LlamaModel(vllm_config=vllm_config,
                                prefix="model",
                                start_layer_id=target_layer_num)

        logit_scale = getattr(self.config, "logit_scale", 1.0)
        self.lm_head = ParallelLMHead(
            self.config.draft_vocab_size,
            self.config.hidden_size,
            org_num_embeddings=self.config.draft_vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE,
            prefix="")
        self.logits_processor = LogitsProcessor(self.config.draft_vocab_size,
                                                scale=logit_scale)
        # Per-draft-token offsets into the target vocabulary:
        # target_id = draft_id + draft_id_to_target_id[draft_id].
        # Filled from the checkpoint's "d2t" tensor when present (see
        # load_weights); otherwise it stays all zeros (identity mapping).
        self.draft_id_to_target_id = nn.Parameter(
            torch.zeros(self.config.draft_vocab_size, dtype=torch.long),
            requires_grad=False,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the draft model body.

        Raises:
            NotImplementedError: If ``inputs_embeds`` is given; precomputed
                (multimodal) embeddings are not supported.
        """
        if inputs_embeds is not None:
            raise NotImplementedError(
                f"{type(self).__name__} does not support multimodal inputs yet."
            )
        return self.model(input_ids, positions, hidden_states)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Compute logits and map them into the target vocabulary.

        Returns:
            ``None`` when the logits processor yields no logits; the draft
            logits unchanged when no draft->target mapping is set; otherwise
            a ``(num_tokens, vocab_size)`` tensor filled with ``-inf`` except
            at the target ids of the draft vocabulary.
        """
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        # The processor's result is Optional (presumably None on ranks that
        # do not receive gathered logits). Propagate it rather than failing
        # on ``logits.shape`` / ``logits.device`` below.
        if logits is None:
            return None

        if self.draft_id_to_target_id is None:
            assert logits.shape[1] == self.config.vocab_size, \
                "Expected logits to have shape " \
                f"(*, {self.config.vocab_size}), but got {logits.shape}"
            return logits

        # draft_id_to_target_id stores offsets, so absolute target ids are
        # the draft ids plus their offsets.
        base = torch.arange(self.config.draft_vocab_size, device=logits.device)
        targets = base + self.draft_id_to_target_id
        logits_new = logits.new_full((
            logits.shape[0],
            self.config.vocab_size,
        ), float('-inf'))
        logits_new[:, targets] = logits
        return logits_new

    def combine_hidden_states(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project concatenated auxiliary hidden states through ``model.fc``."""
        # combine multiple auxiliary hidden states returned by eagle3
        return self.model.fc(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load an EAGLE-3 draft checkpoint.

        Name handling:
          * ``t2d`` tensors are dropped.
          * ``d2t`` is renamed to ``draft_id_to_target_id``.
          * Any other name not containing ``lm_head`` gets a ``model.``
            prefix.

        If the checkpoint has no ``d2t`` or no ``embed_tokens`` weights, those
        parameters are skipped by the loader and keep their current values
        (zeros for the mapping; for ``embed_tokens``, presumably populated
        elsewhere, e.g. shared from the target model -- confirm).

        NOTE(review): returns None rather than the loaded-name set; presumably
        so the model loader does not flag the deliberately skipped
        parameters as missing. Confirm before changing.
        """
        model_weights = {}
        includes_draft_id_mapping = False
        includes_embed_tokens = False
        for name, loaded_weight in weights:
            if "t2d" in name:
                continue
            if "d2t" in name:
                name = name.replace("d2t", "draft_id_to_target_id")
                includes_draft_id_mapping = True
            elif "lm_head" not in name:
                name = "model." + name
            if "embed_tokens" in name:
                includes_embed_tokens = True
            model_weights[name] = loaded_weight

        skip_substrs = []
        if not includes_draft_id_mapping:
            skip_substrs.append("draft_id_to_target_id")
        if not includes_embed_tokens:
            skip_substrs.append("embed_tokens")
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=None,
            skip_substrs=skip_substrs,
        )
        loader.load_weights(model_weights.items())