"vllm/vscode:/vscode.git/clone" did not exist on "63b2206ad01499921428ba50c85a18c92772f26c"
ernie_mtp.py 10.3 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Copyright 2025 The Baidu team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Ernie-MTP model."""
25

26
27
28
29
30
31
from collections.abc import Iterable

import torch
import torch.nn as nn
from transformers import PretrainedConfig

32
from vllm.config import VllmConfig
33
34
35
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
36
37
38
    ParallelLMHead,
    VocabParallelEmbedding,
)
39
40
41
42
43
44
45
46
47
48
49
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsPP
from .llama import LlamaDecoderLayer
from .utils import is_pp_missing_parameter, maybe_prefix


class ErnieMultiTokenPredictorLayer(nn.Module):
    """A single Ernie MTP layer.

    Token embeddings and incoming hidden states are each RMS-normalized,
    concatenated along the last dimension, projected from
    ``2 * hidden_size`` back to ``hidden_size`` and passed through one
    ``LlamaDecoderLayer``.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str,
    ) -> None:
        super().__init__()
        config = vllm_config.model_config.hf_config

        self.mtp_emb_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.mtp_hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # Maps the concatenated [embeddings, hidden states] back to hidden_size.
        self.mtp_linear_proj = nn.Linear(
            config.hidden_size * 2, config.hidden_size, bias=False
        )
        self.mtp_block = LlamaDecoderLayer(vllm_config, prefix)

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        positions: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        spec_step_index: int = 0,
    ) -> torch.Tensor:
        """Run the MTP layer.

        Args:
            inputs_embeds: Token embeddings whose leading dimension(s) match
                ``positions`` (presumably ``(num_tokens, hidden_size)`` —
                TODO confirm). This tensor is not modified.
            positions: Token positions. Embedding rows at position 0 are
                replaced by zeros before normalization.
            previous_hidden_states: Hidden states fused with the embeddings
                (presumably produced by the target model / previous step —
                confirm against the caller).
            spec_step_index: Not used by this layer.

        Returns:
            ``residual + hidden_states`` as returned by the decoder layer;
            no final normalization is applied here.
        """
        assert inputs_embeds is not None
        # Mask inputs at position 0, as they are not needed by MTP.
        # torch.where is out-of-place, so a caller-provided ``inputs_embeds``
        # buffer is left untouched (a masked assignment would write into it).
        inputs_embeds = torch.where(
            (positions == 0).unsqueeze(-1), 0, inputs_embeds
        )

        inputs_embeds = self.mtp_emb_norm(inputs_embeds)
        previous_hidden_states = self.mtp_hidden_norm(previous_hidden_states)

        hidden_states = self.mtp_linear_proj(
            torch.cat([inputs_embeds, previous_hidden_states], dim=-1)
        )

        hidden_states, residual = self.mtp_block(
            positions=positions, hidden_states=hidden_states, residual=None
        )
        # The decoder layer returns the residual stream separately; add it
        # back to obtain the layer output.
        hidden_states = residual + hidden_states

        return hidden_states


class ErnieMultiTokenPredictor(nn.Module):
    """Ernie MTP layers together with their token embedding and logits processor.

    MTP layers are stored in a ``ModuleDict`` keyed by absolute layer index,
    starting at ``config.num_hidden_layers``. Speculative step ``k`` therefore
    uses the layer stored under ``str(num_hidden_layers + k)``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        hf_config = vllm_config.model_config.hf_config
        self.mtp_start_layer_idx = hf_config.num_hidden_layers
        self.num_mtp_layers = hf_config.num_nextn_predict_layers

        # Key each layer by its absolute index so checkpoint layer names map
        # directly onto the submodule names.
        first_idx = self.mtp_start_layer_idx
        end_idx = first_idx + self.num_mtp_layers
        mtp_layers: dict[str, ErnieMultiTokenPredictorLayer] = {}
        for layer_idx in range(first_idx, end_idx):
            mtp_layers[str(layer_idx)] = ErnieMultiTokenPredictorLayer(
                vllm_config,
                f"{prefix}.layers.{layer_idx}",
            )
        self.layers = torch.nn.ModuleDict(mtp_layers)

        self.embed_tokens = VocabParallelEmbedding(
            hf_config.vocab_size,
            hf_config.hidden_size,
        )
        self.logits_processor = LogitsProcessor(hf_config.vocab_size)

    def _layer_for_step(self, spec_step_idx: int) -> nn.Module:
        """Return the MTP layer for ``spec_step_idx`` (KeyError if there is none)."""
        return self.layers[str(self.mtp_start_layer_idx + spec_step_idx)]

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed ``input_ids`` with the MTP token embedding."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        """Run the MTP layer that belongs to ``spec_step_idx``.

        ``input_ids`` are embedded only when ``inputs_embeds`` is not given.
        """
        embeds = inputs_embeds
        if embeds is None:
            embeds = self.embed_tokens(input_ids)
        layer = self._layer_for_step(spec_step_idx)
        return layer(embeds, positions, previous_hidden_states, spec_step_idx)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        lm_head: ParallelLMHead,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        """Project ``hidden_states`` to vocabulary logits through ``lm_head``."""
        # The looked-up layer is not used; the lookup's only effect is a
        # KeyError when spec_step_idx has no corresponding MTP layer.
        _ = self._layer_for_step(spec_step_idx)
        return self.logits_processor(lm_head, hidden_states)


class ErnieMTP(nn.Module, SupportsPP):
    """Inference-only Ernie multi-token-prediction (MTP) model.

    Wraps ``ErnieMultiTokenPredictor`` with an LM head and the checkpoint
    loading logic that maps Ernie's MTP weight names onto this module's
    parameter names. ``forward`` supports only one prediction step
    (``spec_step_idx == 0``).
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        self.config = vllm_config.model_config.hf_config
        self.model = ErnieMultiTokenPredictor(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        self.lm_head = ParallelLMHead(
            self.config.vocab_size,
            self.config.hidden_size,
            prefix=maybe_prefix(prefix, "lm_head"),
        )

        # With tied embeddings the LM head reuses the embedding matrix, and
        # load_weights ignores any "lm_head.weight" tensor in the checkpoint.
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed ``input_ids`` with the MTP token embedding."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        """Run one MTP step.

        Args:
            input_ids: Token ids, embedded when ``inputs_embeds`` is None.
            positions: Token positions.
            hidden_states: Passed to the MTP layer as ``previous_hidden_states``
                (presumably the target model's hidden states — confirm).
            intermediate_tensors: Not used.
            inputs_embeds: Optional precomputed embeddings.
            spec_step_idx: Must be 0; only one predicted token is supported.

        Returns:
            The MTP layer output hidden states.
        """
        assert spec_step_idx == 0, "ernie_mtp only support predict one token"
        hidden_states = self.model(
            input_ids, positions, hidden_states, inputs_embeds, spec_step_idx
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        spec_step_idx: int = 0,
    ) -> torch.Tensor | None:
        """Compute vocabulary logits for ``hidden_states`` via the LM head."""
        return self.model.compute_logits(hidden_states, self.lm_head, spec_step_idx)

    def _should_skip_param(
        self, name: str, params_dict: dict[str, nn.Parameter]
    ) -> bool:
        """Return True if the checkpoint tensor mapped to ``name`` must not be loaded."""
        # Skip loading extra bias for GPTQ models.
        if name.endswith((".bias", "_bias")) and name not in params_dict:
            return True
        # Skip layers on other devices.
        return is_pp_missing_parameter(name, self)

    @staticmethod
    def _find_stacked_mapping(
        name: str,
        params_dict: dict[str, nn.Parameter],
        stacked_params_mapping: list[tuple[str, str, str | int]],
    ) -> tuple[str, str, str | int] | None:
        """Return the first fused-projection mapping applying to ``name``, or None.

        Only MTP weights are fused. Expert weights with no direct parameter
        match are not fused here; they fall through to the plain loading path.
        """
        if "mtp" not in name:
            return None
        if "mlp.experts." in name and name not in params_dict:
            return None
        for mapping in stacked_params_mapping:
            if mapping[1] in name:
                return mapping
        return None

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module.

        Args:
            weights: ``(name, tensor)`` pairs from the checkpoint.

        Returns:
            The (renamed) parameter names that were loaded.
        """
        # (fused parameter name, checkpoint shard name, shard id)
        stacked_params_mapping = [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if self.config.tie_word_embeddings and name.endswith("lm_head.weight"):
                continue
            if "rotary_emb.inv_freq" in name:
                continue
            if "mtp" in name:
                name = self._rewrite_spec_layer_name(self.config, name)

            stacked = self._find_stacked_mapping(
                name, params_dict, stacked_params_mapping
            )
            if stacked is not None:
                param_name, weight_name, shard_id = stacked
                name = name.replace(weight_name, param_name)
                # Skip the whole tensor rather than trying further mappings:
                # the rewritten name contains e.g. "v_proj" (inside "qkv_proj")
                # or "up_proj" (inside "gate_up_proj") and would be mangled.
                if self._should_skip_param(name, params_dict):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
            else:
                if self._should_skip_param(name, params_dict):
                    continue
                # According to DeepSeek-V3 Technical Report, MTP modules share
                # the embedding layer. Only MTP weights and the embedding /
                # LM head are loaded; every other tensor is ignored.
                if (
                    "mtp_" not in name
                    and "embed_tokens" not in name
                    and "lm_head" not in name
                ):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

    def _rewrite_spec_layer_name(self, config: PretrainedConfig, name: str) -> str:
        """Rewrite an MTP checkpoint name to match this module's parameter names.

        Checkpoint tensors stored as ``model.<component>.0.<rest>`` are mapped
        to ``model.layers.<num_hidden_layers>.<component>.<rest>``. Only index
        0 is rewritten.
        """
        layer_idx = config.num_hidden_layers
        spec_layer_weight_names = [
            "embed_tokens",
            "mtp_emb_norm",
            "mtp_hidden_norm",
            "mtp_linear_proj",
        ]
        for weight_name in spec_layer_weight_names:
            if weight_name in name:
                return name.replace(
                    f"model.{weight_name}.0.",
                    f"model.layers.{layer_idx}.{weight_name}.",
                )
        return name.replace(
            "model.mtp_block.0.", f"model.layers.{layer_idx}.mtp_block."
        )