# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only EagleMiniCPM model compatible with HuggingFace weights."""
import math
from collections.abc import Iterable

import torch
from torch import nn
from transformers import PretrainedConfig

from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
40
41
42
    ParallelLMHead,
    VocabParallelEmbedding,
)
43
44
45
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors

46
from .interfaces import SupportsEagle, SupportsLoRA, SupportsPP
47
48
49
from .minicpm import MiniCPMAttention as EagleMiniCPMAttention
from .minicpm import MiniCPMMLP as EagleMiniCPMMLP
from .minicpm import MiniCPMMoE as EagleMiniCPMMoE
50
51
52
53
54
from .utils import (
    AutoWeightsLoader,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    maybe_prefix,
55
    process_eagle_weight,
56
)
57
58
59
60
61
62


class EagleMiniCPMDecoderLayer(nn.Module):
    """One MiniCPM transformer layer used by the EAGLE draft model.

    The attention, dense-MLP and MoE building blocks are the ones from
    ``.minicpm``. Each sub-layer output is multiplied by
    ``scale_depth / sqrt(mup_denominator)`` before it is added back onto the
    residual stream.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # Construction inputs are stored because the _init_* helpers read them.
        self.config = config
        self.cache_config = cache_config
        self.quant_config = quant_config
        self.prefix = prefix

        self.hidden_size = config.hidden_size
        # Optional rotary / context-length settings, with fallbacks when absent.
        self.rope_theta = getattr(config, "rope_theta", 10000)
        self.rope_scaling = getattr(config, "rope_scaling", None)
        self.max_position_embeddings = getattr(config, "max_position_embeddings", 8192)

        self._init_attn_block()
        self._init_ffn_block()

    def _init_attn_block(self):
        """Create the pre-attention RMSNorm followed by the attention module."""
        cfg = self.config
        self.input_layernorm = RMSNorm(cfg.hidden_size, eps=cfg.rms_norm_eps)
        self.self_attn = EagleMiniCPMAttention(
            hidden_size=self.hidden_size,
            num_heads=cfg.num_attention_heads,
            num_kv_heads=cfg.num_key_value_heads,
            rope_theta=self.rope_theta,
            rope_scaling=self.rope_scaling,
            max_position_embeddings=self.max_position_embeddings,
            cache_config=self.cache_config,
            quant_config=self.quant_config,
            prefix=f"{self.prefix}.self_attn",
        )

    def _init_ffn_block(self):
        """Create the post-attention RMSNorm followed by the feed-forward block.

        A MoE block is built when the config reports a non-zero
        ``num_experts``; otherwise (attribute missing or 0) a dense MLP is used.
        """
        cfg = self.config
        self.post_attention_layernorm = RMSNorm(cfg.hidden_size, eps=cfg.rms_norm_eps)
        self.num_experts = getattr(cfg, "num_experts", 0)
        if self.num_experts != 0:
            self.mlp = EagleMiniCPMMoE(
                num_experts=cfg.num_experts,
                top_k=cfg.num_experts_per_tok,
                hidden_size=cfg.hidden_size,
                intermediate_size=cfg.intermediate_size,
            )
            return
        self.mlp = EagleMiniCPMMLP(
            hidden_size=self.hidden_size,
            intermediate_size=cfg.intermediate_size,
            hidden_act=cfg.hidden_act,
            hidden_act_param=getattr(cfg, "hidden_act_param", 0.0),
            quant_config=self.quant_config,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply attention and then the FFN, each as a scaled residual branch.

        The ``residual`` argument is ignored: the residual is kept locally,
        and the second element of the returned tuple is always ``None``.
        """
        cfg = self.config
        branch_scale = cfg.scale_depth / math.sqrt(cfg.mup_denominator)

        # Attention branch: h = h + attn(norm(h)) * scale
        attn_in = self.input_layernorm(hidden_states)
        attn_out = self.self_attn(positions=positions, hidden_states=attn_in)
        hidden_states = hidden_states + attn_out * branch_scale

        # Feed-forward branch: h = h + mlp(norm(h)) * scale
        ffn_in = self.post_attention_layernorm(hidden_states)
        ffn_out = self.mlp(ffn_in)
        hidden_states = hidden_states + ffn_out * branch_scale

        return hidden_states, None


@support_torch_compile
class EagleMiniCPMModel(nn.Module):
    """EAGLE draft backbone for MiniCPM.

    Two inputs feed the backbone: the token embeddings (scaled by
    ``scale_emb``) and the incoming ``hidden_states``. Each is RMS-normalised,
    the two are concatenated, and ``fc`` projects the result back to
    ``hidden_size``. The projection then runs through ``num_hidden_layers``
    :class:`EagleMiniCPMDecoderLayer` instances.
    """

    def __init__(
        self, *, vllm_config: VllmConfig, prefix: str = "", start_layer: int = 0
    ):
        super().__init__()

        # The draft model's HF config is taken from the speculative config.
        config = vllm_config.speculative_config.draft_model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.cache_config = cache_config
        self.quant_config = quant_config
        self.vocab_size = config.vocab_size

        hidden = config.hidden_size
        eps = config.rms_norm_eps
        # Projects [embedding ; incoming hidden state] (2 * hidden) to hidden.
        self.fc = torch.nn.Linear(hidden * 2, hidden, bias=False)
        self.input_norm1 = RMSNorm(hidden, eps=eps)
        self.input_norm2 = RMSNorm(hidden, eps=eps)
        self.embed_tokens = VocabParallelEmbedding(self.vocab_size, hidden)
        self.num_experts = getattr(config, "num_experts", 0)
        self._init_layers(prefix, config, cache_config, quant_config, start_layer)
        # NOTE(review): ``norm`` is constructed but not applied in forward().
        self.norm = RMSNorm(hidden, eps=eps)
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], hidden
        )

    def _init_layers(
        self,
        prefix: str,
        config: PretrainedConfig,
        cache_config: CacheConfig | None,
        quant_config: QuantizationConfig | None,
        start_layer: int,
    ):
        """Populate ``self.eagle_layers`` with ``num_hidden_layers`` layers.

        The layer index embedded in each module prefix is offset by
        ``start_layer``.
        """
        layers = []
        for offset in range(self.config.num_hidden_layers):
            layer_prefix = f"{prefix}.eagle_layers.{start_layer + offset}"
            layers.append(
                EagleMiniCPMDecoderLayer(
                    config,
                    cache_config,
                    quant_config,
                    layer_prefix,
                )
            )
        self.eagle_layers = nn.ModuleList(layers)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings and multiply them by ``scale_emb``."""
        return self.embed_tokens(input_ids) * self.config.scale_emb

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | IntermediateTensors:
        """Fuse token embeddings with the incoming hidden states and run the layers.

        Returns a 2-tuple holding the last layer's output in both slots.
        ``self.norm`` is not applied.
        NOTE(review): the return annotation does not describe this tuple.
        """
        token_feats = self.input_norm1(self.embed_input_ids(input_ids))
        incoming_feats = self.input_norm2(hidden_states)
        x = self.fc(torch.cat((token_feats, incoming_feats), dim=-1))

        residual = None
        for decoder_layer in self.eagle_layers:
            x, residual = decoder_layer(positions, x, residual)

        return x, x

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Copy checkpoint tensors into this module's parameters.

        Routing rules:

        * ``q_proj``/``k_proj``/``v_proj`` and ``gate_proj``/``up_proj`` are
          renamed to the fused ``qkv_proj`` / ``gate_up_proj`` parameters and
          loaded with their shard id.
        * ``experts.<i>.w1/w3.weight`` go to ``ws`` and ``experts.<i>.w2.weight``
          to ``w2s``, loaded with their expert id.
        * Anything else uses the parameter's ``weight_loader``, or
          ``default_weight_loader`` when the parameter has none.

        Rotary-embedding buffers and ``.bias`` tensors without a matching
        parameter are skipped, as are names reported by
        ``is_pp_missing_parameter``. Returns the set of (renamed) names that
        reached the bookkeeping step.
        """
        # (fused parameter, checkpoint sub-name, shard id)
        stacked_params_mapping = [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        # (stacked expert parameter, checkpoint sub-name, expert id)
        expert_params_mapping = []
        for expert_id in range(self.num_experts):
            for proj in ("w1", "w2", "w3"):
                expert_params_mapping.append(
                    (
                        "w2s" if proj == "w2" else "ws",
                        f"experts.{expert_id}.{proj}.weight",
                        expert_id,
                    )
                )
        # Never loaded. Models trained using ColossalAI may include the
        # cos/sin caches in the checkpoint.
        ignored_markers = (
            "rotary_emb.inv_freq",
            "rotary_emb.cos_cached",
            "rotary_emb.sin_cached",
        )

        named_params = dict(self.named_parameters())
        loaded: set[str] = set()
        for name, tensor in weights:
            if any(marker in name for marker in ignored_markers):
                continue
            for fused_name, ckpt_part, shard_id in stacked_params_mapping:
                if ckpt_part not in name:
                    continue
                name = name.replace(ckpt_part, fused_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in named_params:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = named_params[name]
                param.weight_loader(param, tensor, shard_id)
                break
            else:
                # Not a stacked projection: try the per-expert mapping.
                for expert_name, ckpt_part, expert_id in expert_params_mapping:
                    if ckpt_part not in name:
                        continue
                    name = name.replace(ckpt_part, expert_name)
                    if is_pp_missing_parameter(name, self):
                        continue
                    param = named_params[name]
                    param.weight_loader(param, tensor, ckpt_part, expert_id=expert_id)
                    break
                else:
                    # Plain parameter.
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in named_params:
                        continue
                    if is_pp_missing_parameter(name, self):
                        continue
                    param = named_params[name]
                    loader_fn = getattr(param, "weight_loader", default_weight_loader)
                    loader_fn(param, tensor)
            loaded.add(name)
        return loaded


293
class EagleMiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle):
    """MiniCPM EAGLE draft model with its own LM head.

    Wraps :class:`EagleMiniCPMModel`. Both of the backbone's outputs are
    divided by ``hidden_size / dim_model_base``, and ``compute_logits``
    produces logits through ``lm_head``.
    """

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    # LoRA specific attributes
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.speculative_config.draft_model_config.hf_config
        quant_config = vllm_config.quant_config

        self.prefix = prefix
        self.vllm_config = vllm_config
        self.config = config
        self.cache_config = vllm_config.cache_config
        self.quant_config = quant_config

        # The draft layer indices start after the target model's layer count.
        num_target_layers = vllm_config.model_config.get_num_layers(
            vllm_config.parallel_config
        )
        self.model = self._init_model(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "model"),
            start_layer=num_target_layers,
        )

        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        if config.tie_word_embeddings:
            # Tie lm_head to the draft backbone's embed_tokens.
            self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)
        # Width factor that forward() divides the backbone outputs by.
        self.scale_width = config.hidden_size / config.dim_model_base

        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def _init_model(
        self, *, vllm_config: VllmConfig, prefix: str = "", start_layer: int = 0
    ):
        """Construct the EagleMiniCPMModel backbone."""
        backbone = EagleMiniCPMModel(
            vllm_config=vllm_config, prefix=prefix, start_layer=start_layer
        )
        return backbone

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate to the backbone's scaled embedding lookup."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the backbone and divide both of its outputs by ``scale_width``."""
        primary, secondary = self.model(input_ids, positions, hidden_states)
        width = self.scale_width
        return primary / width, secondary / width

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project hidden states to vocabulary logits via ``lm_head``."""
        return self.logits_processor(self.lm_head, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Stream checkpoint tensors into the model via AutoWeightsLoader.

        Each name is passed to ``process_eagle_weight`` as it streams through.
        When embeddings are tied, ``lm_head.`` tensors are skipped.
        """

        def _observe(item):
            name, tensor = item
            process_eagle_weight(self, name)
            return name, tensor

        skipped = ["lm_head."] if self.config.tie_word_embeddings else None
        loader = AutoWeightsLoader(self, skip_prefixes=skipped)
        return loader.load_weights(_observe(item) for item in weights)