deepseek_mtp.py 12 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
from collections.abc import Iterable
from typing import Optional
5
6
7
8
9

import torch
import torch.nn as nn
from transformers import PretrainedConfig

10
from vllm.compilation.decorators import support_torch_compile
11
from vllm.config import VllmConfig
12
13
14
15
16
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
17
18
19
    ParallelLMHead,
    VocabParallelEmbedding,
)
20
21
22
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors

23
from .deepseek_v2 import DeepseekV2DecoderLayer, get_spec_layer_idx_from_weight_name
Jiayi Yao's avatar
Jiayi Yao committed
24
from .interfaces import SupportsPP
25
26
27
28
29
30
31
from .utils import maybe_prefix


class SharedHead(nn.Module):
    """RMSNorm plus vocabulary head attached to a single MTP layer.

    Calling the module only applies the norm; the ``head`` submodule is
    created here but is not invoked by ``forward``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        prefix: str,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Create the norm and the LM head.

        Args:
            config: Read for ``hidden_size``, ``rms_norm_eps`` and
                ``vocab_size``.
            prefix: Parent prefix; combined with ``"head"`` through
                ``maybe_prefix`` to name the LM head.
            quant_config: Passed through to ``ParallelLMHead``.
        """
        super().__init__()
        hidden_size = config.hidden_size
        self.norm = RMSNorm(hidden_size, eps=config.rms_norm_eps)

        head_prefix = maybe_prefix(prefix, "head")
        self.head = ParallelLMHead(
            config.vocab_size,
            hidden_size,
            quant_config=quant_config,
            prefix=head_prefix,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``hidden_states`` after the RMSNorm (the head is not applied)."""
        normed = self.norm(hidden_states)
        return normed


class DeepSeekMultiTokenPredictorLayer(nn.Module):
    """One multi-token-prediction (MTP) layer.

    The layer normalizes the token embeddings and the previous hidden
    states separately, concatenates them, projects the result back to
    ``hidden_size`` and feeds it through a ``DeepseekV2DecoderLayer``.
    """

    def __init__(self, vllm_config: VllmConfig, prefix: str) -> None:
        """Build the norms, the fusion projection, the head and the block.

        Args:
            vllm_config: The HF config is taken from
                ``speculative_config.draft_model_config``; ``quant_config``
                and ``scheduler_config`` are read as well.
            prefix: Name prefix handed to ``SharedHead`` and to the decoder
                layer.
        """
        super().__init__()

        config = vllm_config.speculative_config.draft_model_config.hf_config
        self.config = config
        quant_config = vllm_config.quant_config

        self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # Fuses [embeddings ; previous hidden states] back to hidden_size.
        self.eh_proj = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False)

        # A config that exposes ``index_topk`` gets an int32 buffer of shape
        # (max_num_batched_tokens, index_topk) that is handed to the decoder
        # layer; otherwise the decoder layer receives None.
        self.is_v32 = hasattr(config, "index_topk")
        if self.is_v32:
            topk_tokens = config.index_topk
            topk_indices_buffer = torch.empty(
                vllm_config.scheduler_config.max_num_batched_tokens,
                topk_tokens,
                dtype=torch.int32,
                device="cuda",
            )
        else:
            topk_indices_buffer = None

        self.shared_head = SharedHead(
            config=config, prefix=prefix, quant_config=quant_config
        )
        self.mtp_block = DeepseekV2DecoderLayer(
            vllm_config,
            prefix,
            config=self.config,
            topk_indices_buffer=topk_indices_buffer,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
        spec_step_index: int = 0,
    ) -> torch.Tensor:
        """Run the MTP layer.

        Args:
            input_ids: Not read by this layer; kept for signature parity.
            positions: Token positions; rows whose position is 0 have their
                embeddings zeroed.
            previous_hidden_states: Hidden states to fuse with the embeddings.
            inputs_embeds: Token embeddings. Must not be None.
            spec_step_index: Not read by this layer.

        Returns:
            The decoder block output with its residual added.
        """
        assert inputs_embeds is not None
        # Masking inputs at position 0, as not needed by MTP. This is done
        # out of place so the caller's tensor is left untouched and no
        # boolean-mask index write is needed.
        # NOTE(review): assumes positions.shape == inputs_embeds.shape[:-1]
        # (one position per embedding row) - confirm against callers.
        inputs_embeds = torch.where(positions.unsqueeze(-1) == 0, 0, inputs_embeds)
        inputs_embeds = self.enorm(inputs_embeds)
        previous_hidden_states = self.hnorm(previous_hidden_states)

        hidden_states = self.eh_proj(
            torch.cat([inputs_embeds, previous_hidden_states], dim=-1)
        )

        hidden_states, residual = self.mtp_block(
            positions=positions, hidden_states=hidden_states, residual=None
        )
        hidden_states = residual + hidden_states
        return hidden_states
105
106
107
108
109
110
111
112
113


class DeepSeekMultiTokenPredictor(nn.Module):
    """Container holding every MTP layer, the token embedding table and
    the logits processor."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Create one MTP layer per ``num_nextn_predict_layers``.

        Layers are keyed by ``str(index)`` where indices start at
        ``config.num_hidden_layers``.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config

        first_idx = config.num_hidden_layers
        layer_count = config.num_nextn_predict_layers
        self.mtp_start_layer_idx = first_idx
        self.num_mtp_layers = layer_count

        # to map the exact layer index from weights
        mtp_layers = {}
        for layer_idx in range(first_idx, first_idx + layer_count):
            mtp_layers[str(layer_idx)] = DeepSeekMultiTokenPredictorLayer(
                vllm_config, f"{prefix}.layers.{layer_idx}"
            )
        self.layers = torch.nn.ModuleDict(mtp_layers)

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.logits_processor = LogitsProcessor(config.vocab_size)

    def _resolve_layer(self, spec_step_idx: int) -> tuple[int, nn.Module]:
        """Return ``(step, layer)`` for a speculative step.

        The step wraps around modulo ``num_mtp_layers``.
        """
        step = spec_step_idx % self.num_mtp_layers
        return step, self.layers[str(self.mtp_start_layer_idx + step)]

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed ``input_ids`` with the embedding table."""
        embeddings = self.embed_tokens(input_ids)
        return embeddings

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        """Run the MTP layer selected by ``spec_step_idx``.

        Embeddings are computed from ``input_ids`` when ``inputs_embeds``
        is None.
        """
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        step, layer = self._resolve_layer(spec_step_idx)
        return layer(
            input_ids,
            positions,
            previous_hidden_states,
            inputs_embeds,
            step,
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        """Produce logits using the shared head of the selected MTP layer."""
        _, mtp_layer = self._resolve_layer(spec_step_idx)
        shared_head = mtp_layer.shared_head
        lm_head = shared_head.head
        normed = shared_head(hidden_states)
        return self.logits_processor(lm_head, normed)


166
@support_torch_compile
Jiayi Yao's avatar
Jiayi Yao committed
167
class DeepSeekMTP(nn.Module, SupportsPP):
168
169
170
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Store the HF config and build the wrapped multi-token predictor.

        The predictor is created under the prefix obtained from
        ``maybe_prefix(prefix, "model")``.
        """
        super().__init__()
        self.config = vllm_config.model_config.hf_config
        model_prefix = maybe_prefix(prefix, "model")
        self.model = DeepSeekMultiTokenPredictor(
            vllm_config=vllm_config,
            prefix=model_prefix,
        )
174

175
176
177
    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed ``input_ids`` by delegating to ``self.model``."""
        predictor = self.model
        return predictor.get_input_embeddings(input_ids)

178
179
180
181
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        """Delegate to ``self.model`` and return its output.

        ``intermediate_tensors`` is accepted but not read here. The incoming
        ``hidden_states`` is passed as the predictor's third positional
        argument.
        """
        return self.model(
            input_ids,
            positions,
            hidden_states,
            inputs_embeds,
            spec_step_idx,
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        spec_step_idx: int = 0,
    ) -> Optional[torch.Tensor]:
        """Return the logits computed by ``self.model`` for the given step."""
        predictor = self.model
        logits = predictor.compute_logits(hidden_states, spec_step_idx)
        return logits
198

199
    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load the speculative (MTP) layer tensors from ``weights``.

        Entries for which ``get_spec_layer_idx_from_weight_name`` returns
        None, and rotary ``inv_freq`` entries, are ignored. Every other name
        is first rewritten by ``_rewrite_spec_layer_name`` and then routed,
        in this order, to: a stacked (fused) parameter, an expert parameter,
        or a plain parameter.

        Args:
            weights: ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The set of parameter names that reached the end of the loop body
            without being skipped.
        """
        # (fused parameter name, checkpoint shard name, shard id)
        fused_shards = [
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
            ("fused_qkv_a_proj", "q_a_proj", 0),
            ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
        ]

        # Entries unpack as (param name, checkpoint name, expert id, shard id).
        expert_shards = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.n_routed_experts,
        )

        known_params = dict(self.named_parameters())
        loaded: set[str] = set()

        for ckpt_name, tensor in weights:
            if "rotary_emb.inv_freq" in ckpt_name:
                continue
            spec_layer = get_spec_layer_idx_from_weight_name(self.config, ckpt_name)
            if spec_layer is None:
                continue
            target = self._rewrite_spec_layer_name(spec_layer, ckpt_name)

            for fused_name, shard_name, shard_id in fused_shards:
                # Skip non-stacked layers and experts (experts handled below).
                if shard_name not in target:
                    continue
                # The checkpoint contains mlp.experts[0].gate_proj. Experts
                # are handled by the expert mapping further down, so bail out
                # BEFORE renaming; otherwise the name would become
                # mlp.experts[0].gate_up_proj and later be rewritten again to
                # mlp.experts[0].gate_gate_up_proj, which breaks loading.
                if "mlp.experts." in target and target not in known_params:
                    continue
                candidate = target.replace(shard_name, fused_name)

                # QKV fusion is optional: when the fused parameter does not
                # exist, leave the name alone and fall back to normal loading.
                if fused_name == "fused_qkv_a_proj" and candidate not in known_params:
                    continue
                target = candidate

                # Skip loading extra bias for GPTQ models.
                if target.endswith(".bias") and target not in known_params:
                    continue

                param = known_params[target]
                param.weight_loader(param, tensor, shard_id)
                break
            else:
                for fused_name, shard_name, expert_id, shard_id in expert_shards:
                    if shard_name not in target:
                        continue
                    target = target.replace(shard_name, fused_name)

                    param = known_params[target]
                    param.weight_loader(
                        param,
                        tensor,
                        target,
                        shard_id=shard_id,
                        expert_id=expert_id,
                    )
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if target.endswith(".bias") and target not in known_params:
                        continue

                    # According to DeepSeek-V3 Technical Report, MTP modules
                    # share the embedding layer, so only the copy belonging
                    # to the first MTP layer is loaded.
                    if (
                        spec_layer != self.model.mtp_start_layer_idx
                        and ".layers" not in target
                    ):
                        continue

                    param = known_params[target]
                    loader = getattr(param, "weight_loader", default_weight_loader)
                    loader(param, tensor)
            loaded.add(target)
        return loaded

    def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
        """
        Rewrite the weight name to match the format of the original model.
        Add .mtp_block for modules in transformer layer block for spec layer
296
        and rename shared layer weights to be top level.
297
298
        """
        spec_layer_weight_names = [
299
300
301
302
303
            "embed_tokens",
            "enorm",
            "hnorm",
            "eh_proj",
            "shared_head",
304
        ]
305
        shared_weight_names = ["embed_tokens"]
306
        spec_layer_weight = False
307
        shared_weight = False
308
309
310
        for weight_name in spec_layer_weight_names:
            if weight_name in name:
                spec_layer_weight = True
311
312
                if weight_name in shared_weight_names:
                    shared_weight = True
313
314
315
                break
        if not spec_layer_weight:
            # treat rest weights as weights for transformer layer block
316
317
318
            name = name.replace(
                f"model.layers.{spec_layer}.", f"model.layers.{spec_layer}.mtp_block."
            )
319
320
321
        elif shared_weight:
            # treat shared weights as top level weights
            name = name.replace(f"model.layers.{spec_layer}.", "model.")
322
        return name