deepseek_mtp.py 22.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

zhuwenwen's avatar
zhuwenwen committed
4
5
import os
import re
zhuwenwen's avatar
zhuwenwen committed
6

7
8
import typing
from collections.abc import Callable, Iterable
9
10
11
12
13

import torch
import torch.nn as nn
from transformers import PretrainedConfig

王敏's avatar
王敏 committed
14
from vllm.forward_context import get_forward_context
15
from vllm._aiter_ops import rocm_aiter_ops
16
from vllm.compilation.decorators import support_torch_compile
17
from vllm.config import VllmConfig
18
from vllm.logger import init_logger
19
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
20
21
22
23
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
24
25
26
    ParallelLMHead,
    VocabParallelEmbedding,
)
27
28
29
30
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
31
from vllm.platforms import current_platform
32
33
from vllm.sequence import IntermediateTensors

34
35
from .deepseek_v2 import (
    DeepseekV2DecoderLayer,
36
37
    DeepseekV2MixtureOfExperts,
    DeepseekV2MoE,
38
39
    get_spec_layer_idx_from_weight_name,
)
王敏's avatar
王敏 committed
40
41
42
from vllm.distributed import (tensor_model_parallel_all_gather, 
                              get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
43
from .utils import maybe_prefix
王敏's avatar
王敏 committed
44
from .interfaces import SupportsPP
zhuwenwen's avatar
zhuwenwen committed
45
from vllm import _custom_ops as ops
zhuwenwen's avatar
zhuwenwen committed
46
from vllm.model_executor.layers.quantization.blockwise_int8 import BlockInt8Config
47
import vllm.envs as envs
48
from vllm.utils import direct_register_custom_op
49
logger = init_logger(__name__)
50
51
52
53
54
55


class SharedHead(nn.Module):
    """Per-layer output head of an MTP layer: an RMS norm plus an LM head.

    ``forward`` applies only the norm. The ``head`` module is kept as an
    attribute so that a caller can pass it, together with the normalized
    hidden states, to a logits processor.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        prefix: str,
        quant_config: QuantizationConfig | None = None,
    ) -> None:
        """Build the norm and the LM head.

        Args:
            config: Model config; ``hidden_size``, ``rms_norm_eps`` and
                ``vocab_size`` are read from it.
            prefix: Parameter-name prefix of the owning layer; the head is
                registered under ``maybe_prefix(prefix, "head")``.
            quant_config: Optional quantization config forwarded to the head.
        """
        super().__init__()
        hidden_size = config.hidden_size
        self.norm = RMSNorm(hidden_size, eps=config.rms_norm_eps)
        head_prefix = maybe_prefix(prefix, "head")
        self.head = ParallelLMHead(
            config.vocab_size,
            hidden_size,
            quant_config=quant_config,
            prefix=head_prefix,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``hidden_states`` after RMS normalization (head not applied)."""
        normed = self.norm(hidden_states)
        return normed


class DeepSeekMultiTokenPredictorLayer(nn.Module):
    """One multi-token-prediction (MTP) layer.

    The layer normalizes the token embeddings and the previous hidden states
    separately, projects their concatenation back to ``hidden_size`` with
    ``eh_proj`` and runs the result through a ``DeepseekV2DecoderLayer``.
    """

    def __init__(self, vllm_config: VllmConfig, prefix: str) -> None:
        """Create the norms, the projection, the shared head and the decoder block.

        Args:
            vllm_config: Global config. The HF config is taken from the
                speculative (draft) model config; the quantization config and
                ``scheduler_config.max_num_batched_tokens`` are read as well.
            prefix: Parameter-name prefix passed to the shared head and to the
                decoder block.
        """
        super().__init__()

        hf_config = vllm_config.speculative_config.draft_model_config.hf_config
        self.config = hf_config
        quant_config = vllm_config.quant_config

        hidden_size = hf_config.hidden_size
        self.enorm = RMSNorm(hidden_size, eps=hf_config.rms_norm_eps)
        self.hnorm = RMSNorm(hidden_size, eps=hf_config.rms_norm_eps)
        self.eh_proj = nn.Linear(hidden_size * 2, hidden_size, bias=False)

        self.device = current_platform.device_type

        # DSA is on by default: it is used whenever the config exposes
        # `index_topk`, unless VLLM_DISABLE_DSA forces it off.
        dsa_forced_off = envs.VLLM_DISABLE_DSA
        self.is_v32 = hasattr(hf_config, "index_topk") and not dsa_forced_off

        # The top-k index buffer is only allocated on the DSA path.
        topk_indices_buffer = None
        if self.is_v32:
            topk_indices_buffer = torch.empty(
                vllm_config.scheduler_config.max_num_batched_tokens,
                hf_config.index_topk,
                dtype=torch.int32,
                device=self.device,
            )

        self.shared_head = SharedHead(
            config=hf_config, prefix=prefix, quant_config=quant_config
        )
        self.mtp_block = DeepseekV2DecoderLayer(
            vllm_config,
            prefix,
            config=self.config,
            topk_indices_buffer=topk_indices_buffer,
        )

    # The two functions below are not methods: they are the real and the fake
    # implementation of a custom op that is registered while the class body is
    # being executed.
    def fuse_fill_rms_x2_concat(
        hidden_states_fuse: torch.Tensor,
        positions: torch.Tensor,
        inputs_embeds: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        weight_inputs_embeds: torch.Tensor,
        weight_previous_hidden_states: torch.Tensor,
        epsilon: float,
    ) -> None:
        # Imported inside the function, so `lightop` is only needed when the
        # op is actually executed.
        from lightop import fuse_fill_rms_x2_concat as lightop_kernel

        lightop_kernel(
            hidden_states_fuse,
            positions,
            inputs_embeds,
            previous_hidden_states,
            weight_inputs_embeds,
            weight_previous_hidden_states,
            epsilon,
        )

    def fuse_fill_rms_x2_concat_fake(
        hidden_states_fuse: torch.Tensor,
        positions: torch.Tensor,
        inputs_embeds: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        weight_inputs_embeds: torch.Tensor,
        weight_previous_hidden_states: torch.Tensor,
        epsilon: float,
    ) -> None:
        # Fake implementation: the op returns nothing, so there is nothing to do.
        return None

    # Registered as `torch.ops.vllm.fuse_fill_rms_x2_concat` (see `forward`);
    # `hidden_states_fuse` and `inputs_embeds` are declared as mutated.
    direct_register_custom_op(
        op_name="fuse_fill_rms_x2_concat",
        op_func=fuse_fill_rms_x2_concat,
        mutates_args=["hidden_states_fuse", "inputs_embeds"],
        fake_impl=fuse_fill_rms_x2_concat_fake,
    )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
        spec_step_index: int = 0,
    ) -> torch.Tensor:
        """Run one MTP layer.

        Args:
            input_ids: Accepted for interface compatibility; not read here.
            positions: Token positions; embeddings at position 0 are masked.
            previous_hidden_states: Hidden states fed into ``hnorm``.
            inputs_embeds: Token embeddings; must not be ``None``.
            spec_step_index: Accepted for interface compatibility; not read here.

        Returns:
            The decoder block output with its residual added.
        """
        assert inputs_embeds is not None
        # Inputs at position 0 are masked because MTP does not need them.
        if envs.VLLM_USE_FUSED_FILL_RMS_CAT:
            # Fused path. Presumably the kernel does the masking, both RMS
            # norms and the concatenation in one pass, writing into `fused`
            # (TODO confirm against lightop).
            num_rows = inputs_embeds.shape[0]
            fused_width = inputs_embeds.shape[1] * 2
            fused = torch.empty(
                num_rows,
                fused_width,
                device=inputs_embeds.device,
                dtype=inputs_embeds.dtype,
            )
            torch.ops.vllm.fuse_fill_rms_x2_concat(
                fused,
                positions,
                inputs_embeds,
                previous_hidden_states,
                self.enorm.weight,
                self.hnorm.weight,
                self.enorm.variance_epsilon,
            )
            hidden_states = self.eh_proj(fused)
        else:
            is_first_position = positions.unsqueeze(-1) == 0
            inputs_embeds = torch.where(is_first_position, 0, inputs_embeds)
            inputs_embeds = self.enorm(inputs_embeds)
            previous_hidden_states = self.hnorm(previous_hidden_states)
            concatenated = torch.cat([inputs_embeds, previous_hidden_states], dim=-1)
            hidden_states = self.eh_proj(concatenated)

        hidden_states, residual = self.mtp_block(
            positions=positions, hidden_states=hidden_states, residual=None
        )
        return residual + hidden_states
156
157
158
159
160
161
162
163
164


class DeepSeekMultiTokenPredictor(nn.Module):
    """Container for all MTP layers plus the shared token embedding.

    Layers are stored in a ``ModuleDict`` keyed by their absolute layer index
    (as a string), starting at ``config.num_hidden_layers``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the MTP layers, the embedding table and the logits processor.

        Args:
            vllm_config: Global config; the HF config of the model config is
                used for layer counts and sizes.
            prefix: Parameter-name prefix for the sub-modules.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.mtp_start_layer_idx = config.num_hidden_layers
        self.num_mtp_layers = config.num_nextn_predict_layers

        # Keys are the absolute layer indices so that they line up with the
        # layer index found in the weight names.
        first_idx = self.mtp_start_layer_idx
        layer_indices = range(first_idx, first_idx + self.num_mtp_layers)
        self.layers = torch.nn.ModuleDict(
            {
                str(layer_idx): DeepSeekMultiTokenPredictorLayer(
                    vllm_config, f"{prefix}.layers.{layer_idx}"
                )
                for layer_idx in layer_indices
            }
        )
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            prefix=maybe_prefix(prefix, "embed_tokens"),
        )
        self.logits_processor = LogitsProcessor(config.vocab_size)

        self.tp_rank = get_tensor_model_parallel_rank()
        self.tp_size = get_tensor_model_parallel_world_size()

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up the embeddings of ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        """Run the MTP layer selected by ``spec_step_idx``.

        Args:
            input_ids: Token ids; embedded when ``inputs_embeds`` is ``None``.
            positions: Token positions (may be ``None`` on the MLA-CP path).
            previous_hidden_states: Hidden states from the previous step.
            inputs_embeds: Optional precomputed embeddings.
            spec_step_idx: Speculative step; wrapped modulo the number of
                MTP layers.

        Returns:
            The selected layer's output hidden states.
        """
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        current_step_idx = spec_step_idx % self.num_mtp_layers

        enable_mla_cp = get_forward_context().enable_mla_cp
        if enable_mla_cp:
            # Each tensor-parallel rank keeps only its own slice along dim 0.
            rank, world_size = self.tp_rank, self.tp_size
            inputs_embeds = inputs_embeds.chunk(world_size, dim=0)[rank].contiguous()
            previous_hidden_states = previous_hidden_states.chunk(
                world_size, dim=0
            )[rank].contiguous()
            if positions is not None:
                positions = positions.chunk(world_size, dim=0)[rank].contiguous()

        layer_key = str(self.mtp_start_layer_idx + current_step_idx)
        mtp_layer = self.layers[layer_key]
        hidden_states = mtp_layer(
            input_ids,
            positions,
            previous_hidden_states,
            inputs_embeds,
            current_step_idx,
        )

        if enable_mla_cp:
            # Reassemble the per-rank slices along dim 0.
            hidden_states = tensor_model_parallel_all_gather(
                hidden_states.contiguous(), dim=0
            )

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        """Compute logits with the shared head of the selected MTP layer.

        Args:
            hidden_states: Hidden states to normalize and project.
            spec_step_idx: Speculative step; wrapped modulo the number of
                MTP layers.
        """
        step = spec_step_idx % self.num_mtp_layers
        shared_head = self.layers[str(self.mtp_start_layer_idx + step)].shared_head
        lm_head = shared_head.head
        normed_states = shared_head(hidden_states)
        return self.logits_processor(lm_head, normed_states)


239
@support_torch_compile
王敏's avatar
王敏 committed
240
class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts, SupportsPP):
241
242
243
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the MTP model and record quantization / LLAMA_NN settings.

        Args:
            vllm_config: Global config; its model HF config and quantization
                config are read.
            prefix: Parameter-name prefix; the predictor lives under
                ``maybe_prefix(prefix, "model")``.

        Side effects:
            When a quantization config is present, the environment variables
            ``LLAMA_NN`` and ``LM_NN`` are set to ``"0"`` for the process.
        """
        super().__init__()
        self.config = vllm_config.model_config.hf_config
        model_prefix = maybe_prefix(prefix, "model")
        self.model = DeepSeekMultiTokenPredictor(
            vllm_config=vllm_config, prefix=model_prefix
        )
        # Set MoE hyperparameters
        self.set_moe_parameters()

        quant_config = vllm_config.quant_config
        if quant_config is None:
            self.quant_method = None
        else:
            self.quant_method = quant_config.get_name()
            # Quantized models switch the LLAMA_NN / LM_NN paths off.
            os.environ.update({"LLAMA_NN": "0", "LM_NN": "0"})
        # Evaluated after the override above, so it is False for quantized models.
        self.use_llama_nn = os.environ.get("LLAMA_NN") == "1"
257
                
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275

    def set_moe_parameters(self):
        """Collect the MoE modules of all MTP layers and derive MoE settings.

        Fills ``moe_mlp_layers`` / ``moe_layers`` with the MoE MLPs and their
        ``experts`` sub-modules, then hands the last MoE MLP found (or
        ``None`` when there is none) to ``extract_moe_parameters``.
        """
        self.expert_weights = []
        self.num_moe_layers = self.config.num_nextn_predict_layers
        self.num_expert_groups = self.config.n_group

        self.moe_layers = []
        self.moe_mlp_layers = []
        last_moe_mlp = None
        for mtp_layer in self.model.layers.values():
            assert isinstance(mtp_layer, DeepSeekMultiTokenPredictorLayer)
            decoder_block = mtp_layer.mtp_block
            assert isinstance(decoder_block, DeepseekV2DecoderLayer)
            mlp = decoder_block.mlp
            if not isinstance(mlp, DeepseekV2MoE):
                # Dense MLPs carry no experts to register.
                continue
            last_moe_mlp = mlp
            self.moe_mlp_layers.append(mlp)
            self.moe_layers.append(mlp.experts)
        self.extract_moe_parameters(last_moe_mlp)
276

277
278
    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed ``input_ids`` by delegating to the wrapped predictor."""
        embeddings = self.model.embed_input_ids(input_ids)
        return embeddings
279
280
281

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        """Run the MTP predictor for one speculative step.

        Args:
            input_ids: Token ids, or ``None`` when embeddings are supplied.
            positions: Token positions.
            hidden_states: Hidden states from the previous step; passed to the
                predictor as its previous hidden states.
            intermediate_tensors: Accepted for interface compatibility; not
                read here.
            inputs_embeds: Optional precomputed embeddings.
            spec_step_idx: Index of the speculative step.

        Returns:
            The predictor's output hidden states.
        """
        predicted_states = self.model(
            input_ids,
            positions,
            hidden_states,
            inputs_embeds,
            spec_step_idx,
        )
        return predicted_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        spec_step_idx: int = 0,
298
    ) -> torch.Tensor | None:
299
        return self.model.compute_logits(hidden_states, spec_step_idx)
300

301
    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load the MTP (speculative layer) weights from a checkpoint stream.

        Only weights that belong to a speculative layer are considered; every
        other weight is ignored. Each name is rewritten with
        ``_rewrite_spec_layer_name`` and then loaded through one of three
        routes: stacked parameters (gate/up and fused q_a/kv_a projections),
        expert parameters, or the plain per-parameter loader.

        When ``LLAMA_NN`` is active and the model is not quantized, a set of
        dense weights is additionally re-laid-out with ``ops.trans_w16_gemm``
        after loading.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The set of names recorded for the processed weights. As before,
            the name of a weight that was skipped part-way (e.g. an expert
            that is not mapped to this rank) may still be recorded; ``None``
            is never recorded.
        """
        rocm_aiter_moe_shared_expert_enabled = (
            rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
        )
        # (param_name, shard_name, shard_id)
        stacked_params_mapping = [
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
            ("fused_qkv_a_proj", "q_a_proj", 0),
            ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
        ]

        # With AITER shared-expert fusion the shared experts occupy extra
        # expert slots after the routed ones.
        expert_params_mapping = SharedFusedMoE.make_expert_params_mapping(
            self,
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.n_routed_experts
            + (
                self.config.n_shared_experts
                if rocm_aiter_moe_shared_expert_enabled
                else 0
            ),
        )

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        # Determine whether the model has any "indexer" parameters to load.
        model_has_indexer = any("indexer" in param_name for param_name in params_dict)

        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            # Skip "indexer" weights when the model has no indexer parameters.
            if "indexer" in name and not model_has_indexer:
                logger.info(f"Skipping indexer weight (DSA disabled): {name}")
                continue

            spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
            if spec_layer is None:
                continue

            is_fusion_moe_shared_experts_layer = (
                rocm_aiter_moe_shared_expert_enabled and ("mlp.shared_experts" in name)
            )
            name = self._rewrite_spec_layer_name(spec_layer, name)
            for param_name, weight_name, shard_id in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if ("mlp.experts." in name) and name not in params_dict:
                    continue
                if is_fusion_moe_shared_experts_layer:
                    continue
                name_mapped = name.replace(weight_name, param_name)

                # QKV fusion is optional, fall back to normal
                # weight loading if it's not enabled
                if (
                    param_name == "fused_qkv_a_proj"
                ) and name_mapped not in params_dict:
                    continue
                else:
                    name = name_mapped

                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Special handling: when AITER fusion_shared_experts is enabled,
                # checkpoints may provide a single widened shared_experts tensor
                # without explicit expert indices
                # (e.g. ...mlp.shared_experts.gate_proj.weight).
                # For models with multiple shared experts, split that tensor
                # evenly into per-shared-expert slices and load them into
                # appended expert slots mlp.experts.{n_routed_experts + j}.*
                # accordingly.
                num_chunks = 1
                if is_fusion_moe_shared_experts_layer:
                    num_chunks = getattr(self.config, "n_shared_experts", 1) or 1
                    # Determine split axis based on op type
                    # gate/up: ColumnParallel → split along dim 0
                    # down: RowParallel → split along dim 1
                    split_dim = (
                        1
                        if ("down_proj.weight" in name and loaded_weight.ndim > 1)
                        else 0
                    )
                    total = loaded_weight.shape[split_dim]
                    assert total % num_chunks == 0, (
                        f"Shared expert weight dim {total} "
                        f"not divisible by num_chunks {num_chunks}"
                    )
                    chunk_size = total // num_chunks

                for j in range(num_chunks):
                    chunk_name = name
                    weight_to_load = loaded_weight

                    if is_fusion_moe_shared_experts_layer:
                        if split_dim == 0:
                            weight_to_load = loaded_weight[
                                j * chunk_size : (j + 1) * chunk_size, :
                            ]
                        else:
                            weight_to_load = loaded_weight[
                                :, j * chunk_size : (j + 1) * chunk_size
                            ]
                        # Synthesize an expert-style name so expert mapping
                        # can route it
                        chunk_name = name.replace(
                            "mlp.shared_experts",
                            f"mlp.experts.{self.config.n_routed_experts + j}",
                        )

                    # Use expert_params_mapping to locate the destination
                    # param and delegate to its expert-aware weight_loader
                    # with expert_id.
                    is_expert_weight = False
                    for mapping in expert_params_mapping:
                        param_name, weight_name, expert_id, shard_id = mapping
                        if weight_name not in chunk_name:
                            continue

                        # Anyway, this is an expert weight and should not be
                        # attempted to load as other weights later
                        is_expert_weight = True

                        # Do not modify `name` since the loop may continue here
                        # Instead, create a new variable
                        name_mapped = chunk_name.replace(weight_name, param_name)

                        param = params_dict[name_mapped]
                        # We should ask the weight loader to return success or
                        # not here since otherwise we may skip experts with
                        # other available replicas.
                        weight_loader = typing.cast(
                            Callable[..., bool], param.weight_loader
                        )
                        success = weight_loader(
                            param,
                            weight_to_load,
                            name_mapped,
                            shard_id=shard_id,
                            expert_id=expert_id,
                            return_success=True,
                        )
                        if success:
                            if not is_fusion_moe_shared_experts_layer:
                                name = name_mapped
                            else:
                                loaded_params.add(name_mapped)
                            break
                    else:
                        if is_expert_weight:
                            # We've checked that this is an expert weight
                            # However it's not mapped locally to this rank
                            # So we simply skip it
                            continue

                        # Skip loading extra bias for GPTQ models.
                        if name.endswith(".bias") and name not in params_dict:
                            continue

                        name = maybe_remap_kv_scale_name(name, params_dict)
                        if name is None:
                            continue

                        # According to DeepSeek-V3 Technical Report, MTP modules
                        # shares embedding layer. We only load the first weights.
                        if (
                            spec_layer != self.model.mtp_start_layer_idx
                            and ".layers" not in name
                        ):
                            continue

                        param = params_dict[name]
                        weight_loader = getattr(
                            param, "weight_loader", default_weight_loader
                        )
                        weight_loader(param, loaded_weight)
            # `name` is None when the kv-scale remap above found no target;
            # never record that in a set of parameter names.
            if not is_fusion_moe_shared_experts_layer and name is not None:
                loaded_params.add(name)

        if self.use_llama_nn and self.quant_method is None:
            lay_key_words = [
                "self_attn.eh_proj.weight",
                "self_attn.q_proj.weight",
                "self_attn.q_a_proj.weight",
                "self_attn.q_b_proj.weight",
                "self_attn.kv_a_proj_with_mqa.weight",
                "self_attn.kv_b_proj.weight",
                "self_attn.o_proj.weight",
                "mlp.gate_up_proj.weight",
                "mlp.down_proj.weight",
                "mlp.gate.weight",
                "shared_experts.gate_up_proj.weight",
                "shared_experts.down_proj.weight",
                "shared_head.head.weight",
            ]

            # The keywords are used as regex alternatives unchanged (so "."
            # matches any character); compile the pattern once for the loop.
            layer_pattern = re.compile("|".join(lay_key_words))

            for layername in loaded_params:
                # Match first: names recorded for skipped weights (e.g. experts
                # not mapped to this rank) are not keys of params_dict, and an
                # unconditional lookup would raise KeyError on them.
                if not layer_pattern.search(layername):
                    continue
                weight = params_dict.get(layername)
                if weight is None:
                    continue

                _weight = torch.zeros_like(weight.data)
                ori_shape = _weight.shape

                # Presumably writes the re-laid-out copy of `weight` into
                # `_weight` (TODO confirm against the op's definition).
                ops.trans_w16_gemm(
                    _weight, weight.data, _weight.shape[0], _weight.shape[1]
                )
                weight.data.copy_(_weight)

                # Expose the same storage with the two leading dims swapped.
                weight.data = weight.data.reshape(ori_shape[1], -1)
        return loaded_params

    def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
        """
        Rewrite the weight name to match the format of the original model.
        Add .mtp_block for modules in transformer layer block for spec layer
531
        and rename shared layer weights to be top level.
532
533
        """
        spec_layer_weight_names = [
534
535
536
537
538
            "embed_tokens",
            "enorm",
            "hnorm",
            "eh_proj",
            "shared_head",
539
        ]
540
        shared_weight_names = ["embed_tokens"]
541
        spec_layer_weight = False
542
        shared_weight = False
543
544
545
        for weight_name in spec_layer_weight_names:
            if weight_name in name:
                spec_layer_weight = True
546
547
                if weight_name in shared_weight_names:
                    shared_weight = True
548
549
550
                break
        if not spec_layer_weight:
            # treat rest weights as weights for transformer layer block
551
552
553
            name = name.replace(
                f"model.layers.{spec_layer}.", f"model.layers.{spec_layer}.mtp_block."
            )
554
555
556
        elif shared_weight:
            # treat shared weights as top level weights
            name = name.replace(f"model.layers.{spec_layer}.", "model.")
557
        return name