"plugins/rpmd/vscode:/vscode.git/clone" did not exist on "49e268f119466c18dfb069d731f6b4b36c35cec9"
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os
import re
zhuwenwen's avatar
zhuwenwen committed
6

7
8
import typing
from collections.abc import Callable, Iterable
9
10
11
12
13

import torch
import torch.nn as nn
from transformers import PretrainedConfig

14
from vllm.forward_context import get_forward_context
15
from vllm._aiter_ops import rocm_aiter_ops
16
from vllm.compilation.decorators import support_torch_compile
17
from vllm.config import VllmConfig
18
from vllm.logger import init_logger
19
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
20
21
22
23
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
24
25
26
    ParallelLMHead,
    VocabParallelEmbedding,
)
27
28
29
30
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
31
from vllm.platforms import current_platform
32
33
from vllm.sequence import IntermediateTensors

34
35
from .deepseek_v2 import (
    DeepseekV2DecoderLayer,
36
37
    DeepseekV2MixtureOfExperts,
    DeepseekV2MoE,
38
39
    get_spec_layer_idx_from_weight_name,
)
40
41
42
from vllm.distributed import (tensor_model_parallel_all_gather, 
                              get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
43
from .utils import maybe_prefix
王敏's avatar
王敏 committed
44
from .interfaces import SupportsPP
zhuwenwen's avatar
zhuwenwen committed
45
from vllm import _custom_ops as ops
zhuwenwen's avatar
zhuwenwen committed
46
from vllm.model_executor.layers.quantization.blockwise_int8 import BlockInt8Config
47
import vllm.envs as envs
48
from vllm.utils import direct_register_custom_op
49
logger = init_logger(__name__)
50
51
52
53
54
55


class SharedHead(nn.Module):
    """Output normalisation and LM head owned by one MTP layer.

    ``forward`` applies only the RMSNorm. The ``head`` projection is not
    called here; ``DeepSeekMultiTokenPredictor.compute_logits`` hands it to
    the logits processor together with the normalised hidden states.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        prefix: str,
        quant_config: QuantizationConfig | None = None,
    ) -> None:
        """Build the norm and the vocabulary projection.

        Args:
            config: HF config; ``hidden_size``, ``rms_norm_eps`` and
                ``vocab_size`` are read from it.
            prefix: Parameter-name prefix of the owning layer; the head is
                registered under ``<prefix>.head``.
            quant_config: Optional quantization config forwarded to the head.
        """
        super().__init__()
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "head"),
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``hidden_states`` after RMS normalisation (no projection)."""
        return self.norm(hidden_states)


class DeepSeekMultiTokenPredictorLayer(nn.Module):
    """One multi-token-prediction (MTP) layer.

    The layer RMS-normalises the token embeddings and the previous hidden
    states separately, concatenates them, projects the result back to
    ``hidden_size`` with ``eh_proj`` and runs it through a
    ``DeepseekV2DecoderLayer``.
    """

    def __init__(self, vllm_config: VllmConfig, prefix: str) -> None:
        """Build the layer from the draft model's HF config.

        Args:
            vllm_config: Engine config; the HF config is taken from
                ``speculative_config.draft_model_config``.
            prefix: Parameter-name prefix, passed unchanged to both the
                shared head and the decoder block.
        """
        super().__init__()

        config = vllm_config.speculative_config.draft_model_config.hf_config
        self.config = config
        quant_config = vllm_config.quant_config

        self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.eh_proj = nn.Linear(
            config.hidden_size * 2, config.hidden_size, bias=False
        )

        self.device = current_platform.device_type

        # DSA is on by default: it is used whenever the config defines
        # `index_topk`, unless VLLM_DISABLE_DSA forces it off.
        force_disable_dsa = envs.VLLM_DISABLE_DSA
        self.is_v32 = hasattr(config, "index_topk") and not force_disable_dsa
        if self.is_v32:
            # One row per batched token, `index_topk` int32 slots per row;
            # the buffer is handed to the decoder block below.
            topk_tokens = config.index_topk
            topk_indices_buffer = torch.empty(
                vllm_config.scheduler_config.max_num_batched_tokens,
                topk_tokens,
                dtype=torch.int32,
                device=self.device,
            )
        else:
            topk_indices_buffer = None

        self.shared_head = SharedHead(
            config=config, prefix=prefix, quant_config=quant_config
        )
        self.mtp_block = DeepseekV2DecoderLayer(
            vllm_config,
            prefix,
            config=self.config,
            topk_indices_buffer=topk_indices_buffer,
        )

    # The two functions below take no `self`: they are plain functions that
    # live in the class namespace only so they can be registered as a custom
    # op while the class body is being executed.
    def fuse_fill_rms_x2_concat(
        hidden_states_fuse: torch.Tensor,
        positions: torch.Tensor,
        inputs_embeds: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        weight_inputs_embeds: torch.Tensor,
        weight_previous_hidden_states: torch.Tensor,
        epsilon: float,
    ) -> None:
        """Forward all arguments to ``lightop.fuse_fill_rms_x2_concat``.

        NOTE(review): judging from the unfused branch of ``forward``, the
        kernel presumably does the position-0 masking, both RMS norms and
        the concatenation in one pass, writing into ``hidden_states_fuse``;
        confirm against the lightop documentation.
        """
        # Imported lazily so lightop is only needed when the op is called.
        from lightop import fuse_fill_rms_x2_concat

        fuse_fill_rms_x2_concat(
            hidden_states_fuse,
            positions,
            inputs_embeds,
            previous_hidden_states,
            weight_inputs_embeds,
            weight_previous_hidden_states,
            epsilon,
        )

    def fuse_fill_rms_x2_concat_fake(
        hidden_states_fuse: torch.Tensor,
        positions: torch.Tensor,
        inputs_embeds: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        weight_inputs_embeds: torch.Tensor,
        weight_previous_hidden_states: torch.Tensor,
        epsilon: float,
    ) -> None:
        """Fake implementation for the custom op: does nothing."""
        pass

    # Runs once, when the class body is executed. The op is declared as
    # mutating both the output buffer and `inputs_embeds`.
    direct_register_custom_op(
        op_name="fuse_fill_rms_x2_concat",
        op_func=fuse_fill_rms_x2_concat,
        mutates_args=["hidden_states_fuse", "inputs_embeds"],
        fake_impl=fuse_fill_rms_x2_concat_fake,
    )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
        spec_step_index: int = 0,
    ) -> torch.Tensor:
        """Run one MTP step.

        Args:
            input_ids: Not used by this layer.
            positions: Token positions; rows whose position is 0 have their
                embedding zeroed in the unfused path.
            previous_hidden_states: Hidden states fed to ``hnorm``.
            inputs_embeds: Token embeddings; must not be ``None``.
            spec_step_index: Not used by this layer.

        Returns:
            The decoder block's hidden states with its residual added.
        """
        assert inputs_embeds is not None
        # masking inputs at position 0, as not needed by MTP
        if envs.VLLM_USE_FUSED_FILL_RMS_CAT:
            # Fused path: the custom op fills a (rows, 2 * width) buffer that
            # replaces the separate where / norm / norm / cat sequence.
            hidden_states_fuse = torch.empty(
                inputs_embeds.shape[0],
                inputs_embeds.shape[1] * 2,
                device=inputs_embeds.device,
                dtype=inputs_embeds.dtype,
            )
            torch.ops.vllm.fuse_fill_rms_x2_concat(
                hidden_states_fuse,
                positions,
                inputs_embeds,
                previous_hidden_states,
                self.enorm.weight,
                self.hnorm.weight,
                self.enorm.variance_epsilon,
            )
            hidden_states = self.eh_proj(hidden_states_fuse)
        else:
            inputs_embeds = torch.where(
                positions.unsqueeze(-1) == 0, 0, inputs_embeds
            )
            inputs_embeds = self.enorm(inputs_embeds)
            previous_hidden_states = self.hnorm(previous_hidden_states)

            hidden_states = self.eh_proj(
                torch.cat([inputs_embeds, previous_hidden_states], dim=-1)
            )

        hidden_states, residual = self.mtp_block(
            positions=positions, hidden_states=hidden_states, residual=None
        )
        hidden_states = residual + hidden_states
        return hidden_states
156
157
158
159
160
161
162
163
164


class DeepSeekMultiTokenPredictor(nn.Module):
    """Container for the MTP layers plus the shared token embedding.

    Layers are stored in a ``ModuleDict`` keyed by their layer index as a
    string, starting at ``config.num_hidden_layers``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Create ``num_nextn_predict_layers`` MTP layers and the embedding.

        Args:
            vllm_config: Engine config; the HF config is read from
                ``model_config``.
            prefix: Parameter-name prefix for the sub-modules.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.mtp_start_layer_idx = config.num_hidden_layers
        self.num_mtp_layers = config.num_nextn_predict_layers

        # to map the exact layer index from weights
        self.layers = torch.nn.ModuleDict(
            {
                str(idx): DeepSeekMultiTokenPredictorLayer(
                    vllm_config, f"{prefix}.layers.{idx}"
                )
                for idx in range(
                    self.mtp_start_layer_idx,
                    self.mtp_start_layer_idx + self.num_mtp_layers,
                )
            }
        )
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            prefix=maybe_prefix(prefix, "embed_tokens"),
        )
        self.logits_processor = LogitsProcessor(config.vocab_size)

        # Cached for the per-rank split done in `forward`.
        self.tp_rank = get_tensor_model_parallel_rank()
        self.tp_size = get_tensor_model_parallel_world_size()

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up the embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        """Run the MTP layer selected by ``spec_step_idx``.

        Args:
            input_ids: Token ids; embedded here when ``inputs_embeds`` is
                ``None``.
            positions: Token positions.
            previous_hidden_states: Hidden states passed to the MTP layer.
            inputs_embeds: Optional precomputed embeddings.
            spec_step_idx: Speculative step; taken modulo the number of MTP
                layers to pick the layer.

        Returns:
            Hidden states produced by the selected layer. When the forward
            context enables ``enable_lightly_cp`` they are all-gathered
            along dim 0 before being returned.
        """
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        current_step_idx = spec_step_idx % self.num_mtp_layers

        # Fetch the context once; it is consulted before and after the layer.
        forward_context = get_forward_context()
        enable_lightly_cp = forward_context.enable_lightly_cp
        if enable_lightly_cp:
            scatter_indexes_tensor = forward_context.scatter_indexes_tensor
            if scatter_indexes_tensor is None:
                # No explicit index map: each rank keeps its own chunk of
                # dim 0.
                # NOTE(review): torch.chunk may return fewer than tp_size
                # chunks for short inputs; presumably the caller pads the
                # batch - confirm.
                inputs_embeds = torch.chunk(
                    inputs_embeds, chunks=self.tp_size, dim=0
                )[self.tp_rank].contiguous()
                previous_hidden_states = torch.chunk(
                    previous_hidden_states, chunks=self.tp_size, dim=0
                )[self.tp_rank].contiguous()
                if positions is not None:
                    positions = torch.chunk(
                        positions, chunks=self.tp_size, dim=0
                    )[self.tp_rank].contiguous()
            else:
                # Entries equal to -1 are redirected to row 0 so that
                # index_select receives only valid indices.
                scatter_indexes_tensor = torch.where(
                    scatter_indexes_tensor == -1, 0, scatter_indexes_tensor
                )
                inputs_embeds = torch.index_select(
                    inputs_embeds, 0, scatter_indexes_tensor
                )
                previous_hidden_states = torch.index_select(
                    previous_hidden_states, 0, scatter_indexes_tensor
                )
                if positions is not None:
                    positions = torch.index_select(
                        positions, 0, scatter_indexes_tensor
                    )

        hidden_states = self.layers[
            str(self.mtp_start_layer_idx + current_step_idx)
        ](
            input_ids,
            positions,
            previous_hidden_states,
            inputs_embeds,
            current_step_idx,
        )

        if enable_lightly_cp:
            # Reassemble the per-rank results, then reorder them when a
            # gather index map is provided.
            hidden_states = tensor_model_parallel_all_gather(
                hidden_states.contiguous(), dim=0
            )
            gather_indexes_tensor = forward_context.gather_indexes_tensor
            if gather_indexes_tensor is not None:
                hidden_states = torch.index_select(
                    hidden_states, 0, gather_indexes_tensor
                )

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        """Compute logits with the shared head of the selected MTP layer.

        The hidden states are normalised by the layer's ``shared_head`` and
        passed, together with its ``head`` module, to the logits processor.
        """
        current_step_idx = spec_step_idx % self.num_mtp_layers
        mtp_layer = self.layers[str(self.mtp_start_layer_idx + current_step_idx)]
        logits = self.logits_processor(
            mtp_layer.shared_head.head, mtp_layer.shared_head(hidden_states)
        )
        return logits


251
@support_torch_compile
class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts, SupportsPP):
    """Top-level DeepSeek MTP draft model: predictor, MoE metadata, loader."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the predictor and record quantization / LLAMA_NN settings.

        Side effect: when a quantization config is present, the environment
        variables ``LLAMA_NN`` and ``LM_NN`` are set to ``'0'`` for the whole
        process before ``LLAMA_NN`` is read back.
        """
        super().__init__()
        self.config = vllm_config.model_config.hf_config
        self.model = DeepSeekMultiTokenPredictor(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        # Set MoE hyperparameters
        self.set_moe_parameters()

        quant_config = vllm_config.quant_config

        self.quant_method = None
        if quant_config is not None:
            self.quant_method = quant_config.get_name()
            os.environ['LLAMA_NN'] = '0'
            os.environ['LM_NN'] = '0'
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'

    def set_moe_parameters(self):
        """Collect the MoE modules of every MTP layer.

        Fills ``moe_mlp_layers`` / ``moe_layers`` from the decoder blocks
        whose MLP is a ``DeepseekV2MoE`` and passes the last such MLP (or
        ``None`` when there is none) to ``extract_moe_parameters``.
        """
        self.expert_weights = []
        self.num_moe_layers = self.config.num_nextn_predict_layers
        self.num_expert_groups = self.config.n_group

        self.moe_layers = []
        self.moe_mlp_layers = []
        example_moe = None
        for layer in self.model.layers.values():
            assert isinstance(layer, DeepSeekMultiTokenPredictorLayer)
            layer = layer.mtp_block
            assert isinstance(layer, DeepseekV2DecoderLayer)
            if isinstance(layer.mlp, DeepseekV2MoE):
                example_moe = layer.mlp
                self.moe_mlp_layers.append(layer.mlp)
                self.moe_layers.append(layer.mlp.experts)
        self.extract_moe_parameters(example_moe)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate embedding lookup to the predictor."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        """Run the predictor; ``intermediate_tensors`` is accepted but unused."""
        hidden_states = self.model(
            input_ids, positions, hidden_states, inputs_embeds, spec_step_idx
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        spec_step_idx: int = 0,
    ) -> torch.Tensor | None:
        """Delegate logits computation to the predictor."""
        return self.model.compute_logits(hidden_states, spec_step_idx)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load the MTP-layer weights from a checkpoint stream.

        Weights that do not belong to a speculative layer are ignored. The
        remaining ones are renamed with ``_rewrite_spec_layer_name`` and
        routed, in this order, through the stacked-parameter mapping, the
        expert mapping, and finally the plain per-parameter loader.

        Args:
            weights: ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The names recorded as processed. As in the loop below, this can
            include names that were skipped rather than copied.
        """
        rocm_aiter_moe_shared_expert_enabled = (
            rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
        )
        stacked_params_mapping = [
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
            ("fused_qkv_a_proj", "q_a_proj", 0),
            ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
        ]

        expert_params_mapping = SharedFusedMoE.make_expert_params_mapping(
            self,
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.n_routed_experts
            + (
                self.config.n_shared_experts
                if rocm_aiter_moe_shared_expert_enabled
                else 0
            ),
        )

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        # Determine whether the model has any "indexer" parameters to load.
        model_has_indexer = any(
            "indexer" in param_name for param_name in params_dict
        )

        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            # Skip "indexer" weights when the model has no indexer module.
            if "indexer" in name and not model_has_indexer:
                logger.info("Skipping indexer weight (DSA disabled): %s", name)
                continue

            spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
            if spec_layer is None:
                continue
            is_fusion_moe_shared_experts_layer = (
                rocm_aiter_moe_shared_expert_enabled and ("mlp.shared_experts" in name)
            )
            name = self._rewrite_spec_layer_name(spec_layer, name)
            for param_name, weight_name, shard_id in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if ("mlp.experts." in name) and name not in params_dict:
                    continue
                if is_fusion_moe_shared_experts_layer:
                    continue
                name_mapped = name.replace(weight_name, param_name)

                # QKV fusion is optional, fall back to normal
                # weight loading if it's not enabled
                if (
                    param_name == "fused_qkv_a_proj"
                ) and name_mapped not in params_dict:
                    continue
                else:
                    name = name_mapped

                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Special handling: when AITER fusion_shared_experts is enabled,
                # checkpoints may provide a single widened shared_experts tensor
                # without explicit expert indices
                # (e.g. ...mlp.shared_experts.gate_proj.weight).
                # For models with multiple shared experts, split that tensor
                # evenly into per-shared-expert slices and load them into
                # appended expert slots mlp.experts.{n_routed_experts + j}.*
                # accordingly.
                num_chunks = 1
                if is_fusion_moe_shared_experts_layer:
                    num_chunks = getattr(self.config, "n_shared_experts", 1) or 1
                    # Determine split axis based on op type
                    # gate/up: ColumnParallel → split along dim 0
                    # down: RowParallel → split along dim 1
                    split_dim = (
                        1
                        if ("down_proj.weight" in name and loaded_weight.ndim > 1)
                        else 0
                    )
                    total = loaded_weight.shape[split_dim]
                    assert total % num_chunks == 0, (
                        f"Shared expert weight dim {total} "
                        f"not divisible by num_chunks {num_chunks}"
                    )
                    chunk_size = total // num_chunks

                for j in range(num_chunks):
                    chunk_name = name
                    weight_to_load = loaded_weight

                    if is_fusion_moe_shared_experts_layer:
                        if split_dim == 0:
                            weight_to_load = loaded_weight[
                                j * chunk_size : (j + 1) * chunk_size, :
                            ]
                        else:
                            weight_to_load = loaded_weight[
                                :, j * chunk_size : (j + 1) * chunk_size
                            ]
                        # Synthesize an expert-style name so expert mapping
                        # can route it
                        chunk_name = name.replace(
                            "mlp.shared_experts",
                            f"mlp.experts.{self.config.n_routed_experts + j}",
                        )

                    # Use expert_params_mapping to locate the destination
                    # param and delegate to its expert-aware weight_loader
                    # with expert_id.
                    is_expert_weight = False
                    for mapping in expert_params_mapping:
                        param_name, weight_name, expert_id, shard_id = mapping
                        if weight_name not in chunk_name:
                            continue

                        # Anyway, this is an expert weight and should not be
                        # attempted to load as other weights later
                        is_expert_weight = True

                        # Do not modify `name` since the loop may continue here
                        # Instead, create a new variable
                        name_mapped = chunk_name.replace(weight_name, param_name)

                        param = params_dict[name_mapped]
                        # We should ask the weight loader to return success or
                        # not here since otherwise we may skip experts with
                        # other available replicas.
                        weight_loader = typing.cast(
                            Callable[..., bool], param.weight_loader
                        )
                        success = weight_loader(
                            param,
                            weight_to_load,
                            name_mapped,
                            shard_id=shard_id,
                            expert_id=expert_id,
                            return_success=True,
                        )
                        if success:
                            if not is_fusion_moe_shared_experts_layer:
                                name = name_mapped
                            else:
                                loaded_params.add(name_mapped)
                            break
                    else:
                        if is_expert_weight:
                            # We've checked that this is an expert weight
                            # However it's not mapped locally to this rank
                            # So we simply skip it
                            continue

                        # Skip loading extra bias for GPTQ models.
                        if name.endswith(".bias") and name not in params_dict:
                            continue

                        name = maybe_remap_kv_scale_name(name, params_dict)
                        if name is None:
                            continue

                        # According to DeepSeek-V3 Technical Report, MTP modules
                        # shares embedding layer. We only load the first weights.
                        if (
                            spec_layer != self.model.mtp_start_layer_idx
                            and ".layers" not in name
                        ):
                            continue

                        param = params_dict[name]
                        weight_loader = getattr(
                            param, "weight_loader", default_weight_loader
                        )
                        weight_loader(param, loaded_weight)
            if not is_fusion_moe_shared_experts_layer:
                loaded_params.add(name)

        if self.use_llama_nn and self.quant_method is None:
            self._apply_llama_nn_layout(params_dict, loaded_params)
        return loaded_params

    def _apply_llama_nn_layout(
        self,
        params_dict: dict[str, torch.Tensor],
        loaded_params: set[str],
    ) -> None:
        """Re-layout selected 2-D weights in place via ``ops.trans_w16_gemm``.

        Only names that are real parameters and contain one of the key words
        below are touched. ``loaded_params`` may also hold names that were
        skipped during loading (for example expert weights that are not
        local to this rank, or ``None``), so each name is checked against
        ``params_dict`` first instead of being indexed unconditionally.

        NOTE(review): the exact layout written by ``trans_w16_gemm`` is not
        visible here; the final reshape to ``(cols, -1)`` suggests a
        transposed layout - confirm against the op's documentation.
        """
        lay_key_words = [
            "self_attn.eh_proj.weight",
            "self_attn.q_proj.weight",
            "self_attn.q_a_proj.weight",
            "self_attn.q_b_proj.weight",
            "self_attn.kv_a_proj_with_mqa.weight",
            "self_attn.kv_b_proj.weight",
            "self_attn.o_proj.weight",
            "mlp.gate_up_proj.weight",
            "mlp.down_proj.weight",
            "mlp.gate.weight",
            "shared_experts.gate_up_proj.weight",
            "shared_experts.down_proj.weight",
            "shared_head.head.weight",
        ]
        # Same alternation as before (dots left unescaped); compiled once.
        key_word_pattern = re.compile("|".join(lay_key_words))

        for layername in loaded_params:
            if layername not in params_dict:
                continue
            if not key_word_pattern.search(layername):
                continue

            weight = params_dict[layername]
            _weight = torch.zeros_like(weight.data)
            ori_shape = _weight.shape

            ops.trans_w16_gemm(
                _weight, weight.data, _weight.shape[0], _weight.shape[1]
            )
            weight.data.copy_(_weight)

            weight.data = weight.data.reshape(ori_shape[1], -1)

    def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
        """
        Rewrite the weight name to match the format of the original model.
        Add .mtp_block for modules in transformer layer block for spec layer
        and rename shared layer weights to be top level.
        """
        spec_layer_weight_names = [
            "embed_tokens",
            "enorm",
            "hnorm",
            "eh_proj",
            "shared_head",
        ]
        shared_weight_names = ["embed_tokens"]
        spec_layer_weight = False
        shared_weight = False
        for weight_name in spec_layer_weight_names:
            if weight_name in name:
                spec_layer_weight = True
                if weight_name in shared_weight_names:
                    shared_weight = True
                break
        if not spec_layer_weight:
            # treat rest weights as weights for transformer layer block
            name = name.replace(
                f"model.layers.{spec_layer}.", f"model.layers.{spec_layer}.mtp_block."
            )
        elif shared_weight:
            # treat shared weights as top level weights
            name = name.replace(f"model.layers.{spec_layer}.", "model.")
        return name