deepseek_mtp.py 13.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
zhuwenwen's avatar
zhuwenwen committed
3
4
import os
import re
zhuwenwen's avatar
zhuwenwen committed
5

6
from collections.abc import Iterable
zhuwenwen's avatar
zhuwenwen committed
7
8
from typing import Iterable, Optional

9
10
11
12
13

import torch
import torch.nn as nn
from transformers import PretrainedConfig

14
from vllm.config import VllmConfig
15
16
17
18
19
20
21
22
23
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors

24
from vllm.compilation.decorators import support_torch_compile
25
26
from .deepseek_v2 import (DeepseekV2DecoderLayer,
                          get_spec_layer_idx_from_weight_name)
Jiayi Yao's avatar
Jiayi Yao committed
27
from .interfaces import SupportsPP
28
from .utils import maybe_prefix
zhuwenwen's avatar
zhuwenwen committed
29
from vllm import _custom_ops as ops
zhuwenwen's avatar
zhuwenwen committed
30
from vllm.model_executor.layers.quantization.blockwise_int8 import BlockInt8Config
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51


class SharedHead(nn.Module):
    """Per-MTP-layer output norm plus a ``ParallelLMHead``.

    ``forward`` applies only the RMS norm. The ``head`` module is not invoked
    here; callers use it directly (elsewhere in this file it is handed to a
    logits processor together with the normalized hidden states).
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        hidden_size = config.hidden_size
        vocab_size = config.vocab_size
        self.norm = RMSNorm(hidden_size, eps=config.rms_norm_eps)
        # LM head over `vocab_size` entries of width `hidden_size`;
        # optionally quantized via `quant_config`.
        self.head = ParallelLMHead(vocab_size,
                                   hidden_size,
                                   quant_config=quant_config)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``hidden_states`` after RMS normalization (no projection)."""
        normalized = self.norm(hidden_states)
        return normalized


class DeepSeekMultiTokenPredictorLayer(nn.Module):
    """A single multi-token-prediction (MTP) step.

    The token embeddings and the previous step's hidden states are each
    RMS-normalized, concatenated along the last dimension, projected back to
    ``hidden_size`` by ``eh_proj``, and run through one
    ``DeepseekV2DecoderLayer`` (``mtp_block``). A ``SharedHead`` is attached
    for computing logits from this layer's output.
    """

    def __init__(self, vllm_config: VllmConfig, prefix: str) -> None:
        super().__init__()

        hf_config = vllm_config.model_config.hf_config
        quant_cfg = vllm_config.quant_config
        width = hf_config.hidden_size
        norm_eps = hf_config.rms_norm_eps

        # `enorm` normalizes inputs_embeds, `hnorm` normalizes
        # previous_hidden_states (see forward).
        self.enorm = RMSNorm(width, eps=norm_eps)
        self.hnorm = RMSNorm(width, eps=norm_eps)
        # Maps the concatenated [embeds, hidden] (width * 2) back to width.
        self.eh_proj = nn.Linear(width * 2, width, bias=False)

        # Configs that expose `index_topk` get a pre-allocated int32 index
        # buffer passed to the decoder layer (presumably for the DeepSeek
        # V3.2 sparse-attention indexer, per the `is_v32` name — confirm).
        self.is_v32 = hasattr(hf_config, "index_topk")
        index_buffer = None
        if self.is_v32:
            index_buffer = torch.empty(
                vllm_config.scheduler_config.max_num_batched_tokens,
                hf_config.index_topk,
                dtype=torch.int32,
                device="cuda")

        self.shared_head = SharedHead(config=hf_config,
                                      quant_config=quant_cfg)
        self.mtp_block = DeepseekV2DecoderLayer(vllm_config, prefix,
                                                index_buffer)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
        spec_step_index: int = 0,
    ) -> torch.Tensor:
        """Run one MTP step and return ``residual + hidden_states``.

        ``inputs_embeds`` is required (asserted). Note that rows where
        ``positions == 0`` are zeroed IN PLACE on the caller's tensor.
        ``input_ids`` and ``spec_step_index`` are accepted but not used here.
        """
        assert inputs_embeds is not None
        # MTP does not need the embedding at position 0, so mask it out.
        inputs_embeds[positions == 0] = 0

        # Both last dims must be hidden_size so the concatenation matches
        # eh_proj's (2 * hidden_size) input width.
        combined = torch.cat(
            [self.enorm(inputs_embeds),
             self.hnorm(previous_hidden_states)],
            dim=-1)

        block_out, block_residual = self.mtp_block(
            positions=positions,
            hidden_states=self.eh_proj(combined),
            residual=None)
        return block_residual + block_out


class DeepSeekMultiTokenPredictor(nn.Module):
    """Holds the MTP layers together with the shared embedding and logits processor.

    Layers are stored in a ``ModuleDict`` keyed by their absolute layer index
    (``num_hidden_layers`` up to ``num_hidden_layers +
    num_nextn_predict_layers - 1``) so checkpoint weight names map onto them
    directly. Speculative step indices wrap around modulo the number of MTP
    layers.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.mtp_start_layer_idx = config.num_hidden_layers
        self.num_mtp_layers = config.num_nextn_predict_layers

        # Key each layer by its exact index in the checkpoint.
        per_index = {}
        first_idx = self.mtp_start_layer_idx
        for layer_idx in range(first_idx, first_idx + self.num_mtp_layers):
            per_index[str(layer_idx)] = DeepSeekMultiTokenPredictorLayer(
                vllm_config, f"{prefix}.layers.{layer_idx}")
        self.layers = torch.nn.ModuleDict(per_index)

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.logits_processor = LogitsProcessor(config.vocab_size)

    def _select_layer(self, spec_step_idx: int):
        """Map a speculative step to ``(wrapped_step, mtp_layer)``."""
        wrapped = spec_step_idx % self.num_mtp_layers
        return wrapped, self.layers[str(self.mtp_start_layer_idx + wrapped)]

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        """Embed ``input_ids`` if needed, then run the layer for this step."""
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        wrapped, layer = self._select_layer(spec_step_idx)
        return layer(input_ids, positions, previous_hidden_states,
                     inputs_embeds, wrapped)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        """Compute logits with the shared head of the layer for this step."""
        _, layer = self._select_layer(spec_step_idx)
        shared = layer.shared_head
        return self.logits_processor(shared.head, shared(hidden_states))


@support_torch_compile
class DeepSeekMTP(nn.Module, SupportsPP):
    """DeepSeek multi-token-prediction (MTP) draft model.

    Wraps a ``DeepSeekMultiTokenPredictor`` and loads the MTP layers' weights
    out of a full DeepSeek checkpoint. When quantized, it disables the
    ``LLAMA_NN``/``LM_NN`` env switches and may swap the quant config. When
    unquantized with ``LLAMA_NN=1``, it re-lays-out selected weights after
    loading.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        self.config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config

        # Name of the active quantization method, or None if unquantized.
        self.quant_method = None
        if quant_config is not None:
            self.quant_method = quant_config.get_name()
            # NOTE(review): these assignments are process-global and are not
            # restored afterwards — confirm nothing else depends on them.
            os.environ['LLAMA_NN'] = '0'
            os.environ['LM_NN'] = '0'
            # The AWQ layer of MTP uses BlockInt8W8A8.
            if self.quant_method in ("moe_wna16", "awq_marlin"):
                # Replaces the quant config on the shared vllm_config before
                # the MTP layers below are constructed.
                vllm_config.quant_config = BlockInt8Config(
                    is_checkpoint_int8_serialized=True,
                    weight_block_size=[128, 128])

        self.model = DeepSeekMultiTokenPredictor(vllm_config=vllm_config,
                                                 prefix=maybe_prefix(
                                                     prefix, "model"))
        # Read after the quantized branch above, which forces it to '0'.
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        """Run the MTP layer for ``spec_step_idx``.

        ``intermediate_tensors`` is accepted for the pipeline-parallel
        interface but not used.
        """
        hidden_states = self.model(input_ids, positions, hidden_states,
                                   inputs_embeds, spec_step_idx)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        spec_step_idx: int = 0,
    ) -> Optional[torch.Tensor]:
        """Delegate logits computation to the predictor for this step."""
        return self.model.compute_logits(hidden_states, spec_step_idx)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load MTP-layer weights from a full-checkpoint iterator.

        Weights that ``get_spec_layer_idx_from_weight_name`` does not assign
        to an MTP layer are skipped. Names are rewritten by
        ``_rewrite_spec_layer_name``. Stacked (gate/up, fused q_a/kv_a)
        shards and MoE expert weights are routed to their fused parameters.

        Returns:
            The set of parameter names (as registered on this module) that
            received data.
        """
        # (fused param name, checkpoint shard name, shard id)
        stacked_params_mapping = [
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
            ("fused_qkv_a_proj", "q_a_proj", 0),
            ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
        ]

        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.n_routed_experts)

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
            if spec_layer is None:
                continue
            name = self._rewrite_spec_layer_name(spec_layer, name)
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if (("mlp.experts." in name) and name not in params_dict):
                    continue
                name_mapped = name.replace(weight_name, param_name)

                # QKV fusion is optional, fall back to normal
                # weight loading if it's not enabled
                if ((param_name == "fused_qkv_a_proj")
                        and name_mapped not in params_dict):
                    continue
                name = name_mapped

                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)

                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param,
                                  loaded_weight,
                                  name,
                                  shard_id=shard_id,
                                  expert_id=expert_id)
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue

                    # According to DeepSeek-V3 Technical Report, MTP modules
                    # shares embedding layer. We only load the first weights.
                    if (spec_layer != self.model.mtp_start_layer_idx
                            and ".layers" not in name):
                        continue

                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)

        if self.use_llama_nn and self.quant_method is None:
            self._apply_llama_nn_layout(params_dict, loaded_params)
        return loaded_params

    @staticmethod
    def _apply_llama_nn_layout(params_dict: dict[str, nn.Parameter],
                               loaded_params: set[str]) -> None:
        """Re-lay-out selected loaded weights in place for the LLAMA_NN path.

        For every loaded parameter whose name contains one of the keywords
        below, the data is written by ``ops.trans_w16_gemm`` into a zeroed
        scratch buffer and copied back. The parameter is then reshaped to
        ``(original_shape[1], -1)``, presumably a transpose given the op name
        and that reshape.
        """
        keywords = (
            # NOTE(review): this entry never matches. eh_proj is registered
            # directly on the MTP layer ("...layers.N.eh_proj.weight"), not
            # under self_attn. eh_proj is a plain nn.Linear, so re-laying it
            # out would presumably break its forward — confirm before
            # "fixing" the name.
            "self_attn.eh_proj.weight",
            "self_attn.q_proj.weight",
            "self_attn.q_a_proj.weight",
            "self_attn.q_b_proj.weight",
            "self_attn.kv_a_proj_with_mqa.weight",
            "self_attn.kv_b_proj.weight",
            "self_attn.o_proj.weight",
            "mlp.gate_up_proj.weight",
            "mlp.down_proj.weight",
            "mlp.gate.weight",
            "shared_experts.gate_up_proj.weight",
            "shared_experts.down_proj.weight",
            "shared_head.head.weight",
        )
        # Escape so '.' matches literally; compile once rather than per param.
        pattern = re.compile("|".join(map(re.escape, keywords)))

        for layername in loaded_params:
            if pattern.search(layername) is None:
                continue
            weight = params_dict[layername]
            scratch = torch.zeros_like(weight.data)
            ori_shape = scratch.shape

            ops.trans_w16_gemm(scratch, weight.data, ori_shape[0],
                               ori_shape[1])
            weight.data.copy_(scratch)

            # Swap the leading two dims to match the re-laid-out data.
            weight.data = weight.data.reshape(ori_shape[1], -1)

    def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
        """
        Rewrite the weight name to match the format of the original model.
        Add .mtp_block for modules in transformer layer block for spec layer
        and rename shared layer weights to be top level.
        """
        spec_layer_weight_names = [
            "embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head"
        ]
        shared_weight_names = ["embed_tokens"]
        spec_layer_weight = False
        shared_weight = False
        for weight_name in spec_layer_weight_names:
            if weight_name in name:
                spec_layer_weight = True
                if weight_name in shared_weight_names:
                    shared_weight = True
                break
        if not spec_layer_weight:
            # treat rest weights as weights for transformer layer block
            name = name.replace(f"model.layers.{spec_layer}.",
                                f"model.layers.{spec_layer}.mtp_block.")
        elif shared_weight:
            # treat shared weights as top level weights
            name = name.replace(f"model.layers.{spec_layer}.", "model.")
        return name