# SPDX-License-Identifier: Apache-2.0

from collections.abc import Iterable
from typing import Optional

import torch
import torch.nn as nn
from transformers import LlamaConfig

from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed.parallel_state import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import QKVParallelLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import (LlamaDecoderLayer,
                                              LlamaForCausalLM)
from vllm.v1.sample.metadata import SamplingMetadata

from .utils import AutoWeightsLoader, maybe_prefix

logger = init_logger(__name__)


class LlamaDecoderLayer(LlamaDecoderLayer):
    """Single EAGLE-3 draft decoder layer.

    Extends the base Llama decoder layer so that self-attention consumes the
    concatenation of the normalized token embeddings and the normalized
    incoming hidden states. Because that attention input is twice the model
    width, the QKV projection is rebuilt with an input size of
    ``2 * hidden_size``.

    NOTE: this class intentionally shadows the imported base class of the
    same name; later references to ``LlamaDecoderLayer`` in this module
    resolve to this subclass.
    """

    def __init__(
        self,
        config: LlamaConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(config, quant_config=quant_config, prefix=prefix)

        # Replace the inherited QKV projection: attention here receives
        # [embeds, hidden_states] concatenated on the last dim, so the input
        # width is doubled.
        self.self_attn.qkv_proj = QKVParallelLinear(
            2 * self.hidden_size,
            self.self_attn.head_dim,
            self.self_attn.total_num_heads,
            self.self_attn.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "qkv_proj"),
        )

        # Dedicated norm for the incoming hidden states; the embeddings are
        # normalized with the inherited ``input_layernorm``.
        self.hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        embeds: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run one draft decoder step.

        Args:
            positions: Token positions forwarded to self-attention.
            embeds: Token embeddings of the draft input ids.
            hidden_states: Incoming hidden states that are combined with
                ``embeds``.
            residual: Accepted for signature compatibility but ignored; the
                residual stream is always seeded from ``hidden_states``.

        Returns:
            ``(hidden_states, residual)``: the MLP output and the residual
            returned by ``post_attention_layernorm``.
        """
        # The un-normalized incoming hidden states seed the residual stream;
        # whatever residual the caller passed in is discarded.
        residual = hidden_states
        embeds = self.input_layernorm(embeds)
        hidden_states = self.hidden_norm(hidden_states)

        # Attention input is the concatenation of both normalized streams,
        # matching the doubled-width qkv_proj built in __init__.
        hidden_states = torch.cat([embeds, hidden_states], dim=-1)
        # Self Attention
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )

        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)

        # Fully Connected
        hidden_states = self.mlp(hidden_states)

        return hidden_states, residual


@support_torch_compile
class LlamaModel(nn.Module):
    """Backbone of the EAGLE-3 Llama draft model.

    Holds a single draft decoder layer (``self.layers[0]``), a final
    ``RMSNorm`` and the ``fc`` projection that folds concatenated auxiliary
    hidden states down to ``hidden_size``. The token embedding is only
    created here when pipeline parallelism is enabled; otherwise the draft
    shares the target model's embedding, which is presumably attached to
    ``self.embed_tokens`` by the caller before ``forward`` runs (confirm
    against the EAGLE proposer).
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        start_layer_id: int = 0,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = vllm_config. \
            speculative_config.draft_model_config.hf_config
        self.vocab_size = self.config.vocab_size

        # if PP disabled then draft will share embed with target
        if get_pp_group().world_size > 1:
            self.embed_tokens = VocabParallelEmbedding(
                self.config.vocab_size,
                self.config.hidden_size,
                prefix=maybe_prefix(prefix, "embed_tokens"),
            )

        # Exactly one decoder layer. Its prefix uses ``start_layer_id``
        # rather than 0 -- presumably so its attention layer name does not
        # collide with the target model's layers (confirm with the caller).
        self.layers = nn.ModuleList([
            LlamaDecoderLayer(
                self.config,
                prefix=maybe_prefix(prefix, f"layers.{start_layer_id}"),
            )
        ])

        # ``fc`` takes 3x the target model's hidden size when the draft
        # config declares ``target_hidden_size``, else 3x its own hidden size.
        fc_hidden_size = getattr(self.config, "target_hidden_size",
                                 self.config.hidden_size)
        self.fc = torch.nn.Linear(fc_hidden_size * 3,
                                  self.config.hidden_size,
                                  bias=False)
        self.norm = RMSNorm(
            self.config.hidden_size,
            eps=self.config.rms_norm_eps,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the single draft layer followed by the final norm.

        Args:
            input_ids: Draft token ids to embed.
            positions: Token positions forwarded to the decoder layer.
            hidden_states: Incoming hidden states; their last dimension must
                match the embedding width.

        Returns:
            ``(hidden_states, hidden_prenorm)``: the normalized output and
            the second value returned by ``self.norm`` when given a residual.
        """
        input_embeds = self.embed_tokens(input_ids)
        assert hidden_states.shape[-1] == input_embeds.shape[-1]

        # The decoder layer ignores the incoming residual and seeds it from
        # hidden_states itself.
        residual = None
        hidden_states, residual = self.layers[0](
            positions,
            input_embeds,
            hidden_states,
            residual,
        )

        hidden_states, hidden_prenorm = self.norm(hidden_states, residual)
        return hidden_states, hidden_prenorm

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load draft checkpoint tensors into this module.

        Checkpoint names under ``midlayer.`` are remapped to ``layers.0.``.
        Separate q/k/v and gate/up tensors are routed into the fused
        ``qkv_proj`` / ``gate_up_proj`` parameters through their shard-aware
        ``weight_loader``. ``embed_tokens`` tensors are skipped when this
        module owns no embedding (pipeline parallelism disabled, embedding
        shared with the target).

        Args:
            weights: Iterable of ``(name, tensor)`` pairs.

        Returns:
            The set of parameter names that were loaded.

        Raises:
            KeyError: If a (remapped) checkpoint name has no matching
                parameter.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if 'midlayer.' in name:
                name = name.replace('midlayer.', 'layers.0.')
            if "embed_tokens" in name and name not in params_dict:
                # PP disabled: the draft shares the target's embedding and
                # owns no embed_tokens parameter, so there is nothing to fill.
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # Fused parameter: load this tensor as one shard of it.
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


class Eagle3LlamaForCausalLM(LlamaForCausalLM):
    """EAGLE-3 Llama draft model exposing the causal-LM interface.

    The draft head predicts over ``draft_vocab_size`` tokens.
    ``compute_logits`` scatters those logits into a full ``vocab_size`` row
    using per-token offsets stored in ``draft_id_to_target_id`` (loaded from
    the checkpoint's ``d2t`` tensor); target ids with no draft counterpart
    get ``-inf``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # Skip LlamaForCausalLM.__init__: the draft builds its own backbone
        # and head from the speculative draft config.
        nn.Module.__init__(self)
        self.config = vllm_config. \
            speculative_config.draft_model_config.hf_config
        target_layer_num = vllm_config.model_config.get_num_layers(
            vllm_config.parallel_config)
        self.model = LlamaModel(vllm_config=vllm_config,
                                prefix="model",
                                start_layer_id=target_layer_num)

        logit_scale = getattr(self.config, "logit_scale", 1.0)
        self.lm_head = ParallelLMHead(
            self.config.draft_vocab_size,
            self.config.hidden_size,
            org_num_embeddings=self.config.draft_vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE,
            prefix=maybe_prefix(prefix, "lm_head"))
        self.logits_processor = LogitsProcessor(self.config.draft_vocab_size,
                                                scale=logit_scale)
        # Offset from each draft token id to its target id
        # (target_id = draft_id + offset); all-zero means identity mapping.
        self.draft_id_to_target_id = nn.Parameter(
            torch.zeros(self.config.draft_vocab_size, dtype=torch.long),
            requires_grad=False,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Delegate to the draft backbone; see ``LlamaModel.forward``."""
        return self.model(input_ids, positions, hidden_states)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Compute draft logits and map them into the target vocabulary.

        Returns:
            A tensor with ``vocab_size`` columns where each draft token's
            logit sits at its target id and every other entry is ``-inf``,
            or ``None`` when the logits processor produced no logits.
        """
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        if logits is None:
            # Honour the Optional contract instead of dereferencing None.
            return None
        base = torch.arange(self.config.draft_vocab_size, device=logits.device)
        targets = base + self.draft_id_to_target_id
        logits_new = logits.new_full((
            logits.shape[0],
            self.config.vocab_size,
        ), float('-inf'))
        logits_new[:, targets] = logits
        return logits_new

    def combine_hidden_states(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project concatenated auxiliary hidden states through ``model.fc``.

        The last dimension of ``hidden_states`` must equal ``fc``'s input
        size (3x the target hidden size).
        """
        # combine multiple auxiliary hidden states returned by eagle3
        return self.model.fc(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Rename draft checkpoint tensors and load them.

        - ``t2d`` tensors are dropped.
        - ``d2t`` is loaded into ``draft_id_to_target_id``.
        - Every other tensor except ``lm_head`` is prefixed with ``model.``
          so it is routed to the backbone.
        """
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=None,
        )

        model_weights = {}
        for name, loaded_weight in weights:
            if "t2d" in name:
                continue
            if "d2t" in name:
                name = name.replace("d2t", "draft_id_to_target_id")
            elif "lm_head" not in name:
                name = "model." + name
            model_weights[name] = loaded_weight

        return loader.load_weights(model_weights.items())