# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable

import torch
import torch.nn as nn
from transformers import PretrainedConfig

from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors

from .deepseek_v2 import (
    DeepseekV2DecoderLayer,
    DeepseekV2MixtureOfExperts,
    DeepseekV2MoE,
    get_spec_layer_idx_from_weight_name,
)
from .interfaces import SupportsPP
from .utils import maybe_prefix

# Module-level logger for this file (not referenced by the code below).
logger = init_logger(__name__)


class SharedHead(nn.Module):
    """Per-MTP-layer output head: a final RMSNorm plus LM-head weights.

    ``forward`` only applies the norm. The ``head`` projection is not applied
    here; it is passed to a ``LogitsProcessor`` together with the normalized
    hidden states when logits are computed.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        prefix: str,
        quant_config: QuantizationConfig | None = None,
    ) -> None:
        super().__init__()
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "head"),
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``hidden_states`` normalized by ``self.norm``."""
        return self.norm(hidden_states)


class DeepSeekMultiTokenPredictorLayer(nn.Module):
    """A single multi-token-prediction (MTP) module.

    The current token embeddings and the previous step's hidden states are
    each RMS-normalized, concatenated, projected back to ``hidden_size`` by
    ``eh_proj`` and run through one ``DeepseekV2DecoderLayer``
    (``mtp_block``). The module also owns a ``SharedHead`` used when
    computing logits.
    """

    def __init__(self, vllm_config: VllmConfig, prefix: str) -> None:
        super().__init__()

        # The layer is built from the draft model's HF config.
        config = vllm_config.speculative_config.draft_model_config.hf_config
        self.config = config
        quant_config = vllm_config.quant_config

        self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # Maps concat(embeds, hidden) of width 2 * hidden_size back down.
        self.eh_proj = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False)

        self.device = current_platform.device_type

        # Configs that define `index_topk` are flagged as v3.2 and get a
        # preallocated int32 buffer of shape
        # (max_num_batched_tokens, index_topk) handed to the decoder layer.
        self.is_v32 = hasattr(config, "index_topk")
        if self.is_v32:
            topk_indices_buffer = torch.empty(
                vllm_config.scheduler_config.max_num_batched_tokens,
                config.index_topk,
                dtype=torch.int32,
                device=self.device,
            )
        else:
            topk_indices_buffer = None

        self.shared_head = SharedHead(
            config=config, prefix=prefix, quant_config=quant_config
        )
        self.mtp_block = DeepseekV2DecoderLayer(
            vllm_config,
            prefix,
            config=self.config,
            topk_indices_buffer=topk_indices_buffer,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
        spec_step_index: int = 0,
    ) -> torch.Tensor:
        """Run this MTP module for one speculative step.

        Args:
            input_ids: Not used by this method; embeddings must be supplied
                through ``inputs_embeds``.
            positions: Token positions. Embeddings of tokens at position 0
                are zeroed before use.
            previous_hidden_states: Hidden states from the previous step.
            inputs_embeds: Token embeddings; must not be ``None`` (asserted).
            spec_step_index: Not used by this method.

        Returns:
            ``residual + hidden_states`` as produced by ``mtp_block``.
        """
        assert inputs_embeds is not None
        # Mask inputs at position 0, as they are not needed by MTP.
        # NOTE(review): assumes positions is (num_tokens,) and inputs_embeds
        # is (num_tokens, hidden_size); unsqueeze(-1) broadcasts the mask
        # across the hidden dimension.
        inputs_embeds = torch.where(positions.unsqueeze(-1) == 0, 0, inputs_embeds)
        inputs_embeds = self.enorm(inputs_embeds)
        previous_hidden_states = self.hnorm(previous_hidden_states)

        hidden_states = self.eh_proj(
            torch.cat([inputs_embeds, previous_hidden_states], dim=-1)
        )

        hidden_states, residual = self.mtp_block(
            positions=positions, hidden_states=hidden_states, residual=None
        )
        # The decoder layer returns the residual separately; fold it back in.
        hidden_states = residual + hidden_states
        return hidden_states


class DeepSeekMultiTokenPredictor(nn.Module):
    """All MTP layers plus the shared token embedding and logits processor.

    Layers are stored in a ``ModuleDict`` keyed by their absolute layer index
    in the checkpoint (``num_hidden_layers`` up to
    ``num_hidden_layers + num_nextn_predict_layers - 1``). Speculative steps
    are mapped onto those layers round-robin.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.mtp_start_layer_idx = config.num_hidden_layers
        self.num_mtp_layers = config.num_nextn_predict_layers
        # Keys use the exact layer index from the weights, so checkpoint
        # names map directly onto modules.
        self.layers = torch.nn.ModuleDict(
            {
                str(idx): DeepSeekMultiTokenPredictorLayer(
                    vllm_config, f"{prefix}.layers.{idx}"
                )
                for idx in range(
                    self.mtp_start_layer_idx,
                    self.mtp_start_layer_idx + self.num_mtp_layers,
                )
            }
        )
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.logits_processor = LogitsProcessor(config.vocab_size)

    def _layer_for_step(self, spec_step_idx: int) -> tuple[int, nn.Module]:
        """Return ``(layer_offset, layer)`` serving speculative step ``spec_step_idx``.

        Step indices wrap around modulo ``num_mtp_layers``, so layers are
        reused round-robin when there are more steps than MTP layers.
        """
        current_step_idx = spec_step_idx % self.num_mtp_layers
        layer = self.layers[str(self.mtp_start_layer_idx + current_step_idx)]
        return current_step_idx, layer

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed ``input_ids`` with the shared token embedding."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        """Run the MTP layer chosen for ``spec_step_idx``.

        ``inputs_embeds`` is computed from ``input_ids`` when not provided.
        The selected layer receives the wrapped (modulo) step index.
        """
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        current_step_idx, mtp_layer = self._layer_for_step(spec_step_idx)
        return mtp_layer(
            input_ids,
            positions,
            previous_hidden_states,
            inputs_embeds,
            current_step_idx,
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        """Compute logits using the shared head of the layer for ``spec_step_idx``."""
        _, mtp_layer = self._layer_for_step(spec_step_idx)
        # SharedHead.forward only normalizes; the LM-head weights are applied
        # by the logits processor.
        logits = self.logits_processor(
            mtp_layer.shared_head.head, mtp_layer.shared_head(hidden_states)
        )
        return logits


@support_torch_compile
class DeepSeekMTP(nn.Module, SupportsPP, DeepseekV2MixtureOfExperts):
    """DeepSeek multi-token-prediction (MTP) draft model.

    Wraps ``DeepSeekMultiTokenPredictor`` and exposes ``embed_input_ids``,
    ``forward``, ``compute_logits`` and ``load_weights``. It also records the
    MoE bookkeeping set up by ``set_moe_parameters``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        self.config = vllm_config.model_config.hf_config
        self.model = DeepSeekMultiTokenPredictor(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        # Set MoE hyperparameters
        self.set_moe_parameters()

    def set_moe_parameters(self):
        """Collect the MoE blocks of every MTP layer and record MoE metadata.

        Fills ``moe_mlp_layers`` with each layer's ``DeepseekV2MoE`` and
        ``moe_layers`` with that MoE's ``experts``. It then passes the last
        MoE found (``None`` if no layer has one) to ``extract_moe_parameters``.
        """
        self.expert_weights = []
        self.num_moe_layers = self.config.num_nextn_predict_layers
        self.num_expert_groups = self.config.n_group

        self.moe_layers = []
        self.moe_mlp_layers = []
        example_moe = None
        for mtp_layer in self.model.layers.values():
            assert isinstance(mtp_layer, DeepSeekMultiTokenPredictorLayer)
            block = mtp_layer.mtp_block
            assert isinstance(block, DeepseekV2DecoderLayer)
            if isinstance(block.mlp, DeepseekV2MoE):
                example_moe = block.mlp
                self.moe_mlp_layers.append(block.mlp)
                self.moe_layers.append(block.mlp.experts)
        self.extract_moe_parameters(example_moe)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed ``input_ids`` with the predictor's shared token embedding."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        """Run the MTP layer selected by ``spec_step_idx``.

        ``hidden_states`` are the previous step's hidden states.
        ``intermediate_tensors`` is accepted for interface compatibility and
        is not used here.
        """
        hidden_states = self.model(
            input_ids, positions, hidden_states, inputs_embeds, spec_step_idx
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        spec_step_idx: int = 0,
    ) -> torch.Tensor | None:
        """Compute logits with the shared head of the layer for ``spec_step_idx``."""
        return self.model.compute_logits(hidden_states, spec_step_idx)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load the MTP-layer weights from a full checkpoint stream.

        Rotary ``inv_freq`` buffers and names that
        ``get_spec_layer_idx_from_weight_name`` does not map to an MTP layer
        are skipped. Each remaining name is rewritten by
        ``_rewrite_spec_layer_name``. It is then loaded, in order of
        precedence, as a stacked/fused parameter shard, a fused-MoE expert
        shard, or a plain parameter.

        Returns:
            The set of (rewritten) parameter names that were loaded.
        """
        # (fused parameter name, checkpoint weight name, shard id)
        stacked_params_mapping = [
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
            ("fused_qkv_a_proj", "q_a_proj", 0),
            ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
        ]

        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.n_routed_experts,
        )

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
            if spec_layer is None:
                continue
            name = self._rewrite_spec_layer_name(spec_layer, name)
            for param_name, weight_name, shard_id in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if ("mlp.experts." in name) and name not in params_dict:
                    continue
                name_mapped = name.replace(weight_name, param_name)

                # QKV fusion is optional; fall back to normal
                # weight loading if it's not enabled.
                if param_name == "fused_qkv_a_proj" and name_mapped not in params_dict:
                    continue
                name = name_mapped

                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # No stacked shard matched: try the fused-MoE expert mapping.
                for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)

                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(
                        param,
                        loaded_weight,
                        name,
                        shard_id=shard_id,
                        expert_id=expert_id,
                    )
                    break
                else:
                    # Neither stacked nor expert: load as a plain parameter.
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue

                    name = maybe_remap_kv_scale_name(name, params_dict)
                    if name is None:
                        continue

                    # According to the DeepSeek-V3 Technical Report, MTP
                    # modules share the embedding layer. Top-level weights
                    # (names without ".layers") are only loaded from the
                    # first MTP layer.
                    if (
                        spec_layer != self.model.mtp_start_layer_idx
                        and ".layers" not in name
                    ):
                        continue

                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

    def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
        """Rewrite a checkpoint weight name to this module's parameter layout.

        Weights that belong to the MTP module itself (``enorm``, ``hnorm``,
        ``eh_proj``, ``shared_head``) keep their name. Shared weights
        (``embed_tokens``) are moved to the top level
        (``model.layers.{spec_layer}.`` -> ``model.``). All other weights are
        treated as part of the transformer block and get ``.mtp_block``
        inserted after the layer prefix. Matching is by substring; the first
        listed name found in ``name`` decides.
        """
        spec_layer_weight_names = [
            "embed_tokens",
            "enorm",
            "hnorm",
            "eh_proj",
            "shared_head",
        ]
        shared_weight_names = ["embed_tokens"]
        layer_prefix = f"model.layers.{spec_layer}."

        matched = next((w for w in spec_layer_weight_names if w in name), None)
        if matched is None:
            # Treat the remaining weights as weights of the transformer block.
            return name.replace(layer_prefix, f"{layer_prefix}mtp_block.")
        if matched in shared_weight_names:
            # Treat shared weights as top-level weights.
            return name.replace(layer_prefix, "model.")
        return name