deepseek_v2.py 72.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

wangding zeng's avatar
wangding zeng committed
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
25
"""Inference-only DeepseekV2/DeepseekV3 model."""
王敏's avatar
王敏 committed
26
27
import os
import re
28
import vllm.envs as envs
zhuwenwen's avatar
zhuwenwen committed
29

30
31
import typing
from collections.abc import Callable, Iterable
32
from typing import Any, Optional, Union, Tuple
wangding zeng's avatar
wangding zeng committed
33
34
35
36
37

import torch
from torch import nn
from transformers import PretrainedConfig

38
from vllm.attention import Attention
39
from vllm.compilation.decorators import support_torch_compile
40
41
from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
                         get_current_vllm_config)
王敏's avatar
王敏 committed
42
from vllm.distributed import (get_ep_group, get_pp_group, get_dp_group,
王敏's avatar
王敏 committed
43
44
45
46
                              get_tensor_model_parallel_world_size,
                              tensor_model_parallel_all_gather,
                              get_tensor_model_parallel_rank,
                              tensor_model_parallel_reduce_scatter)
wangding zeng's avatar
wangding zeng committed
47
from vllm.model_executor.layers.activation import SiluAndMul
王敏's avatar
王敏 committed
48
49
50
from vllm.model_executor.layers.fused_moe import FusedMoE, SharedFusedMoE

from vllm.model_executor.layers.fused_moe.utils import EPSharedExperts
wangding zeng's avatar
wangding zeng committed
51
52
53
54
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               MergedColumnParallelLinear,
                                               ReplicatedLinear,
55
                                               FusedQuantedReplicatedLinear,
wangding zeng's avatar
wangding zeng committed
56
57
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
58
from vllm.model_executor.layers.quantization import QuantizationConfig
zhuwenwen's avatar
zhuwenwen committed
59
from vllm.model_executor.layers.rotary_embedding import get_rope
wangding zeng's avatar
wangding zeng committed
60
61
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
62
63
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, maybe_remap_kv_scale_name)
wangding zeng's avatar
wangding zeng committed
64
from vllm.model_executor.sampling_metadata import SamplingMetadata
65
from vllm.sequence import IntermediateTensors
wangding zeng's avatar
wangding zeng committed
66

67
from .interfaces import MixtureOfExperts, SupportsPP
68
from .utils import (PPMissingLayer, is_pp_missing_parameter,
69
70
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
王敏's avatar
王敏 committed
71
from vllm import _custom_ops as ops
72
from vllm.utils import W8a8GetCacheJSON
73

wangding zeng's avatar
wangding zeng committed
74
75
76
77
78
79
80
81
82
class DeepseekV2MLP(nn.Module):
    """SiLU-gated MLP: merged gate/up projection -> SiluAndMul -> down projection.

    Also used as the shared-expert MLP by DeepseekV2MoE. Environment flags pick
    between a plain path, a fused RMSNorm+quant path (``USE_FUSED_RMS_QUANT``)
    and a path that consumes pre-computed ``xqxs`` inputs
    (``USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT``).
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # Gate and up projections share one column-parallel GEMM.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size, intermediate_size],
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=f"{prefix}.down_proj",
        )
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x,
                rms_weight: Optional[torch.Tensor] = None,
                residual: Optional[torch.Tensor] = None,
                update_hd: Optional[bool] = False,
                xqxs: Optional[tuple[torch.Tensor, torch.Tensor]] = None
                ) -> Union[torch.Tensor,
                           Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
        """Apply the MLP to ``x``.

        Under ``envs.USE_FUSED_RMS_QUANT`` the gate/up projection also receives
        ``rms_weight``/``residual``/``update_hd``. The return value is then the
        4-tuple ``(output, new_residual, quantized_input, scales)`` taken from
        that projection. Otherwise a single tensor is returned.
        """
        if envs.USE_FUSED_RMS_QUANT:
            merged, new_residual, x_quant, x_scales, _ = self.gate_up_proj(
                x, rms_weight, residual, update_hd=update_hd)
            projected = self._activate_and_project(merged, allow_fused=True)
            return projected, new_residual, x_quant, x_scales

        if envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and xqxs is not None:
            merged, _ = self.gate_up_proj(x, xqxs=xqxs)
            return self._activate_and_project(merged, allow_fused=True)

        merged, _ = self.gate_up_proj(x)
        return self._activate_and_project(merged, allow_fused=False)

    def _activate_and_project(self, merged: torch.Tensor,
                              allow_fused: bool) -> torch.Tensor:
        """SiLU-and-mul followed by the down projection.

        When ``allow_fused`` is set and ``envs.USE_FUSED_SILU_MUL_QUANT`` is
        enabled, the activation is delegated to ``down_proj`` via
        ``use_fused_silu_mul_quant=True``. The env flag is only read when
        ``allow_fused`` is true.
        """
        if allow_fused and envs.USE_FUSED_SILU_MUL_QUANT:
            projected, _ = self.down_proj(merged,
                                          use_fused_silu_mul_quant=True)
        else:
            projected, _ = self.down_proj(self.act_fn(merged))
        return projected
wangding zeng's avatar
wangding zeng committed
131
132
133
134
135
136
137
138


class DeepseekV2MoE(nn.Module):
    """DeepSeek-V2/V3 mixture-of-experts block: routed experts plus optional
    shared experts.

    Construction picks one of three layouts:
      * non-DeepEP, shared experts overlapped with the routed experts inside
        ``SharedFusedMoE`` (``enable_shared_experts_overlap``);
      * non-DeepEP, shared experts run separately as a ``DeepseekV2MLP`` and
        routed experts in a plain ``FusedMoE``;
      * DeepEP (dp > 1, expert parallel, a ``deepep_*`` all2all backend), where
        ``EPSharedExperts`` (or ``None``) is handed to ``SharedFusedMoE``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        enable_eplb: bool = False,
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.routed_scaling_factor = config.routed_scaling_factor

        self.ep_group = get_ep_group().device_group
        self.ep_rank = self.ep_group.rank()
        self.ep_size = self.ep_group.size()
        self.n_routed_experts: int = config.n_routed_experts
        # May be None for configs without shared experts.
        self.n_shared_experts: Optional[int] = config.n_shared_experts

        if config.hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {config.hidden_act}. "
                             "Only silu is supported for now.")

        self.gate = ReplicatedLinear(config.hidden_size,
                                     config.n_routed_experts,
                                     bias=False,
                                     quant_config=None,
                                     prefix=f"{prefix}.gate")
        if config.topk_method == "noaux_tc":
            self.gate.e_score_correction_bias = nn.Parameter(
                torch.empty(config.n_routed_experts))
        else:
            self.gate.e_score_correction_bias = None

        # Load balancing settings.
        vllm_config = get_current_vllm_config()
        parallel_config = vllm_config.parallel_config
        self.enable_eplb = enable_eplb

        self.n_redundant_experts = parallel_config.num_redundant_experts
        self.n_logical_experts = self.n_routed_experts
        self.n_physical_experts = (self.n_logical_experts +
                                   self.n_redundant_experts)
        self.n_local_physical_experts = self.n_physical_experts // self.ep_size

        self.physical_expert_start = (self.ep_rank *
                                      self.n_local_physical_experts)
        self.physical_expert_end = (self.physical_expert_start +
                                    self.n_local_physical_experts)

        dp_size = get_dp_group().world_size
        self.enable_expert_parallel = parallel_config.enable_expert_parallel
        self.use_deepep = (dp_size > 1
                           and parallel_config.enable_expert_parallel
                           and envs.VLLM_ALL2ALL_BACKEND in (
                               "deepep_high_throughput",
                               "deepep_low_latency",
                               "deepep_auto"))

        self.enable_shared_experts_overlap = False
        # Default so the DeepEP constructor below and forward() work when the
        # config has no shared experts (reading an unset attribute would raise).
        self.shared_experts = None

        # Arguments shared by every routed-expert layer variant below.
        moe_kwargs = dict(
            num_experts=config.n_routed_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            reduce_results=False,
            renormalize=config.norm_topk_prob,
            quant_config=quant_config,
            use_grouped_topk=True,
            num_expert_group=config.n_group,
            topk_group=config.topk_group,
            prefix=f"{prefix}.experts",
            scoring_func=config.scoring_func,
            e_score_correction_bias=self.gate.e_score_correction_bias,
            enable_eplb=self.enable_eplb,
            num_redundant_experts=self.n_redundant_experts,
            routed_scaling_factor=self.routed_scaling_factor,
        )

        if not self.use_deepep:
            if config.n_shared_experts is not None:
                intermediate_size = (config.moe_intermediate_size *
                                     config.n_shared_experts)
                self.shared_experts = DeepseekV2MLP(
                    hidden_size=config.hidden_size,
                    intermediate_size=intermediate_size,
                    hidden_act=config.hidden_act,
                    quant_config=quant_config,
                    reduce_results=False,
                    prefix=f"{prefix}.shared_experts",
                )
                # Overlap is incompatible with the fused RMS-quant and lightop
                # sum/mul/add paths, and can be disabled explicitly.
                self.enable_shared_experts_overlap = (
                    not envs.VLLM_DISABLE_SHARED_EXPERTS_STREAM
                    and not envs.USE_FUSED_RMS_QUANT
                    and not envs.VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD)

            if self.enable_shared_experts_overlap:
                self.experts = SharedFusedMoE(
                    shared_experts=self.shared_experts, **moe_kwargs)
            else:
                self.experts = FusedMoE(**moe_kwargs)
        else:
            if config.n_shared_experts is not None:
                intermediate_size = (config.moe_intermediate_size *
                                     config.n_shared_experts)
                self.shared_experts = EPSharedExperts(
                    hidden_size=config.hidden_size,
                    intermediate_size=intermediate_size,
                    hidden_act=config.hidden_act,
                    quant_config=quant_config,
                    reduce_results=False,
                    prefix=f"{prefix}.shared_experts",
                )
            self.experts = SharedFusedMoE(
                shared_experts=self.shared_experts, **moe_kwargs)

        # True when shared experts exist but are not overlapped inside
        # self.experts, i.e. forward() must call self.shared_experts itself.
        self.run_shared_expert_singlely = (
            self.n_shared_experts is not None
            and not self.enable_shared_experts_overlap)

        from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce
        self.tbo_all_reduce = tbo_all_reduce

    def forward(self, hidden_states: torch.Tensor,
                rms_weight: Optional[torch.Tensor] = None,
                residual: Optional[torch.Tensor] = None,
                xqxs: Optional[tuple[torch.Tensor, torch.Tensor]] = None
                ) -> Union[torch.Tensor,
                           Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
        """Run routed (and shared) experts over ``hidden_states``.

        Args:
            hidden_states: 2-D (num_tokens, hidden_dim) activations.
            rms_weight, residual: forwarded to the shared experts when
                ``envs.USE_FUSED_RMS_QUANT`` is set.
            xqxs: forwarded to the shared experts on the
                ``USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT`` path (presumably a
                (quantized activations, scales) pair — confirm with caller).

        Returns:
            The expert output viewed as (num_tokens, hidden_dim). Under
            ``envs.USE_FUSED_RMS_QUANT`` the return value is instead
            ``(output, new_resi, i_q, i_s)``. The last three entries are
            ``None`` on paths where the shared experts did not produce them.
        """
        num_tokens, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        is_fp16 = hidden_states.dtype == torch.float16

        # Bind everything later branches and the final return read, so paths
        # that skip the shared experts (none configured, overlap mode, DeepEP)
        # do not raise UnboundLocalError/NameError.
        shared_output = None
        new_resi = None
        i_q, i_s = None, None

        if envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and xqxs is not None:
            if self.run_shared_expert_singlely:
                shared_output = self.shared_experts(hidden_states, xqxs=xqxs)
            router_logits, _ = self.gate(hidden_states)

            if envs.VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD:
                # shared_output is combined inside the experts call.
                final_hidden_states = self.experts(
                    hidden_states=hidden_states,
                    router_logits=router_logits,
                    shared_output=shared_output)
            elif self.enable_shared_experts_overlap:
                final_hidden_states = self._run_overlapped_experts(
                    hidden_states, router_logits, scale_routed=True)
            else:
                if is_fp16:
                    # Fix FP16 overflow: leave routed output unscaled; the
                    # shared output is divided instead (_add_shared_output).
                    final_hidden_states = self.experts(
                        hidden_states=hidden_states,
                        router_logits=router_logits)
                else:
                    final_hidden_states = self.experts(
                        hidden_states=hidden_states,
                        router_logits=router_logits
                    ) * self.routed_scaling_factor
                final_hidden_states = self._add_shared_output(
                    final_hidden_states, shared_output, is_fp16)

        elif not self.enable_expert_parallel:
            if self.run_shared_expert_singlely:
                if envs.USE_FUSED_RMS_QUANT:
                    shared_output, new_resi, i_q, i_s = self.shared_experts(
                        hidden_states, rms_weight, residual, update_hd=True)
                else:
                    shared_output = self.shared_experts(hidden_states)
            router_logits, _ = self.gate(hidden_states)

            if self.enable_shared_experts_overlap:
                final_hidden_states = self._run_overlapped_experts(
                    hidden_states, router_logits, scale_routed=True)
            elif envs.VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD:
                final_hidden_states = self.experts(
                    hidden_states=hidden_states,
                    router_logits=router_logits,
                    shared_output=shared_output,
                    i_q=i_q, i_s=i_s)
            else:
                if is_fp16:
                    # Fix FP16 overflow (see _add_shared_output); fp16 mode
                    # does not feed the fused-quant inputs to the experts.
                    final_hidden_states = self.experts(
                        hidden_states=hidden_states,
                        router_logits=router_logits)
                else:
                    final_hidden_states = self.experts(
                        hidden_states=hidden_states,
                        router_logits=router_logits,
                        i_q=i_q, i_s=i_s) * self.routed_scaling_factor
                final_hidden_states = self._add_shared_output(
                    final_hidden_states, shared_output, is_fp16)

        else:
            # NOTE(review): unlike the non-EP paths above, the routed output
            # is not multiplied by routed_scaling_factor here — confirm this
            # asymmetry is intended for the EP/DeepEP expert layers.
            router_logits, _ = self.gate(hidden_states)
            if self.use_deepep:
                shared_output, final_hidden_states = self.experts(
                    hidden_states=hidden_states,
                    router_logits=router_logits)
                final_hidden_states = self._add_shared_output(
                    final_hidden_states, shared_output, is_fp16)
            else:
                if self.run_shared_expert_singlely:
                    if envs.USE_FUSED_RMS_QUANT:
                        # DeepseekV2MLP returns a 4-tuple in this mode.
                        # NOTE(review): i_q/i_s are returned but not passed
                        # to self.experts here, unlike the non-EP path.
                        shared_output, new_resi, i_q, i_s = \
                            self.shared_experts(hidden_states, rms_weight,
                                                residual, update_hd=True)
                    else:
                        shared_output = self.shared_experts(hidden_states)

                if self.enable_shared_experts_overlap:
                    final_hidden_states = self._run_overlapped_experts(
                        hidden_states, router_logits, scale_routed=False)
                else:
                    final_hidden_states = self.experts(
                        hidden_states=hidden_states,
                        router_logits=router_logits)
                    final_hidden_states = self._add_shared_output(
                        final_hidden_states, shared_output, is_fp16)

        if self.tp_size > 1:
            if envs.VLLM_ENABLE_TBO:
                final_hidden_states = self.tbo_all_reduce(final_hidden_states)
            else:
                final_hidden_states = (
                    self.experts.maybe_all_reduce_tensor_model_parallel(
                        final_hidden_states))

        final_hidden_states = final_hidden_states.view(num_tokens, hidden_dim)
        if envs.USE_FUSED_RMS_QUANT:
            return final_hidden_states, new_resi, i_q, i_s
        return final_hidden_states

    def _add_shared_output(self, final_hidden_states: torch.Tensor,
                           shared_output: Optional[torch.Tensor],
                           is_fp16: bool) -> torch.Tensor:
        """Out-of-place add of the shared-expert output, if any.

        In fp16 the shared output is scaled by 1/routed_scaling_factor instead
        of scaling the routed output up (FP16 overflow fix; see
        DeepseekV2DecoderLayer for more details).
        """
        if shared_output is None:
            return final_hidden_states
        if not is_fp16:
            return final_hidden_states + shared_output
        return final_hidden_states + shared_output \
            * (1. / self.routed_scaling_factor)

    def _run_overlapped_experts(self, hidden_states: torch.Tensor,
                                router_logits: torch.Tensor,
                                scale_routed: bool) -> torch.Tensor:
        """Run SharedFusedMoE (shared + routed experts) and combine in place.

        ``self.experts`` returns ``(shared_output, routed_output)`` here and is
        given a clone of the input as ``hidden_states_copy``. When
        ``scale_routed`` is set, non-fp16 routed output is multiplied by
        ``routed_scaling_factor`` before the shared output is added.
        """
        assert self.shared_experts is not None
        hidden_states_copy = hidden_states.clone()
        shared_output, final_hidden_states = self.experts(
            hidden_states=hidden_states,
            router_logits=router_logits,
            hidden_states_copy=hidden_states_copy)
        # Fix FP16 overflow
        # See DeepseekV2DecoderLayer for more details.
        if hidden_states.dtype != torch.float16:
            if scale_routed:
                final_hidden_states *= self.routed_scaling_factor
            final_hidden_states += shared_output
        else:
            assert shared_output is not None
            final_hidden_states += (shared_output *
                                    (1.0 / self.routed_scaling_factor))
        return final_hidden_states
wangding zeng's avatar
wangding zeng committed
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469


def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    """Return the YaRN attention magnitude correction for a RoPE ``scale``.

    Scales of 1 or less need no correction and yield exactly 1.0. Larger
    scales yield ``1 + 0.1 * mscale * ln(scale)``.
    """
    import math
    return 1.0 if scale <= 1 else 1.0 + 0.1 * mscale * math.log(scale)


class DeepseekV2Attention(nn.Module):
    """DeepSeek-V2 attention computed with explicit (non-MLA) q/k/v tensors.

    Queries come from a low-rank path (q_a_proj -> RMSNorm -> q_b_proj) when
    ``q_lora_rank`` is set, otherwise from a direct ``q_proj``. Keys and values
    are decompressed from the latent produced by ``kv_a_proj_with_mqa``. Only
    the last ``qk_rope_head_dim`` channels of q and k get rotary embeddings.
    Values are zero-padded to ``qk_head_dim`` for the attention call and sliced
    back to ``v_head_dim`` afterwards.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        hidden_size: int,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: Optional[int],
        kv_lora_rank: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        reduce_results: bool = True,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.num_heads = num_heads
        tp_size = get_tensor_model_parallel_world_size()
        assert num_heads % tp_size == 0
        self.num_local_heads = num_heads // tp_size
        self.scaling = self.qk_head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        if self.q_lora_rank is not None:
            self.q_a_proj = ReplicatedLinear(self.hidden_size,
                                             self.q_lora_rank,
                                             bias=False,
                                             quant_config=quant_config,
                                             prefix=f"{prefix}.q_a_proj")
            self.q_a_layernorm = RMSNorm(self.q_lora_rank,
                                         eps=config.rms_norm_eps)
            self.q_b_proj = ColumnParallelLinear(q_lora_rank,
                                                 self.num_heads *
                                                 self.qk_head_dim,
                                                 bias=False,
                                                 quant_config=quant_config,
                                                 prefix=f"{prefix}.q_b_proj")
        else:
            self.q_proj = ColumnParallelLinear(self.hidden_size,
                                               self.num_heads *
                                               self.qk_head_dim,
                                               bias=False,
                                               quant_config=quant_config,
                                               prefix=f"{prefix}.q_proj")

        # Produces the compressed KV latent plus the shared rope key part.
        self.kv_a_proj_with_mqa = ReplicatedLinear(
            self.hidden_size,
            self.kv_lora_rank + self.qk_rope_head_dim,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_a_proj_with_mqa")
        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
                                      eps=config.rms_norm_eps)
        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_b_proj")
        # O projection.
        self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
                                        self.hidden_size,
                                        bias=False,
                                        quant_config=quant_config,
                                        prefix=f"{prefix}.o_proj",
                                        reduce_results=reduce_results)

        if rope_scaling:
            # Work on a copy so the caller's (typically config-owned) dict is
            # not mutated as a side effect of building a layer.
            rope_scaling = dict(rope_scaling)
            rope_scaling["rope_type"] = 'deepseek_yarn'

        self.rotary_emb = get_rope(qk_rope_head_dim,
                                   rotary_dim=qk_rope_head_dim,
                                   max_position=max_position_embeddings,
                                   base=rope_theta,
                                   rope_scaling=rope_scaling,
                                   is_neox_style=False)

        if rope_scaling:
            # YaRN magnitude correction applied to the softmax scale
            # (squared because it would otherwise be applied to both q and k).
            mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
            scaling_factor = rope_scaling["factor"]
            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
            self.scaling = self.scaling * mscale * mscale

        self.attn = Attention(self.num_local_heads,
                              self.qk_head_dim,
                              self.scaling,
                              num_kv_heads=self.num_local_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Compute attention for ``hidden_states`` at token ``positions``."""
        if self.q_lora_rank is not None:
            q = self.q_a_proj(hidden_states)[0]
            q = self.q_a_layernorm(q)
            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads,
                                         self.qk_head_dim)
        else:
            q = self.q_proj(hidden_states)[0].view(-1, self.num_local_heads,
                                                   self.qk_head_dim)
        # Only the rope part of q is needed separately; the nope part stays
        # in place inside q.
        _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim],
                          dim=-1)
        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
        kv_a, _ = latent_cache.split(
            [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        latent_cache = latent_cache.unsqueeze(1)
        kv_a = self.kv_a_layernorm(kv_a.contiguous())
        kv = self.kv_b_proj(kv_a)[0]
        kv = kv.view(-1, self.num_local_heads,
                     self.qk_nope_head_dim + self.v_head_dim)
        k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
        # Single (unsqueezed) head dimension: the rope key is shared across
        # heads and broadcast when written into k below.
        k_pe = latent_cache[:, :, self.kv_lora_rank:]

        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)

        q[..., self.qk_nope_head_dim:] = q_pe
        k = torch.empty_like(q)
        k[..., :self.qk_nope_head_dim] = k_nope
        k[..., self.qk_nope_head_dim:] = k_pe
        # padding value to qk_head_dim for alignment
        v = torch.nn.functional.pad(
            v, [0, self.qk_head_dim - self.v_head_dim],
            value=0).view(-1, self.num_local_heads * self.qk_head_dim)
        attn_output = self.attn(q, k, v)
        # Drop the value padding before the output projection.
        attn_output = attn_output.view(
            -1, self.num_local_heads,
            self.qk_head_dim)[..., :self.v_head_dim].reshape(
                -1, self.num_local_heads * self.v_head_dim)
        output, _ = self.o_proj(attn_output)
        return output


605
606
607
608
class DeepseekV2MLAAttention(nn.Module):
    """Multi-head latent attention (MLA) for DeepseekV2/DeepseekV3.

    Main reference: DeepseekV2 paper, and FlashInfer Implementation
    (https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).

    For more info see MLACommonImpl in: vllm/attention/backends/mla/utils.py

    Queries come from either a low-rank chain (``q_a_proj`` ->
    ``q_a_layernorm`` -> ``q_b_proj``, when ``q_lora_rank`` is set) or a
    single ``q_proj``. Keys/values are carried as a compressed latent
    ``kv_c`` (``kv_lora_rank`` wide) plus a decoupled rotary part ``k_pe``
    (``qk_rope_head_dim`` wide). Both come out of one ``kv_a`` projection and
    are handed to the MLA ``Attention`` backend together with ``kv_b_proj``.

    ``forward`` has three paths, selected by environment flags and by which
    optional arguments are supplied:

    * ``USE_FUSED_RMS_QUANT`` and ``rms_weight`` given: the decoder's input
      RMSNorm weight and residual are passed into the first projection
      (``qa_kva_proj`` or ``q_a_proj``), which also returns the new residual.
      Returns ``(output, new_residual)``.
    * ``USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT`` and both ``pa_rms_weight`` and
      ``pa_residual`` given: the post-attention RMSNorm parameters are passed
      into ``o_proj``. Returns the first four values ``o_proj`` yields
      (hidden states, residual and, presumably, quantized activations and
      their scales).
    * otherwise: plain projections. Returns the output tensor.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        hidden_size: int,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: Optional[int],
        kv_lora_rank: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        reduce_results: bool = True,
    ) -> None:
        """Build the projections, rotary embedding and MLA backend.

        Args:
            config: HF config; only ``rms_norm_eps`` is read here.
            hidden_size: Model hidden size (input/output width).
            num_heads: Total attention heads; must divide evenly by the
                tensor-parallel world size.
            qk_nope_head_dim: Per-head width of the non-rotary q/k part.
            qk_rope_head_dim: Per-head width of the rotary q/k part.
            v_head_dim: Per-head value width.
            q_lora_rank: Rank of the low-rank query projection, or ``None``
                to use a single ``q_proj``.
            kv_lora_rank: Width of the compressed KV latent.
            rope_theta: Rotary base.
            rope_scaling: Rotary scaling dict. When non-empty it is mutated
                in place (``rope_type`` forced to ``'deepseek_yarn'``) and its
                ``factor``/``mscale_all_dim`` rescale the softmax scale.
            max_position_embeddings: Maximum rotary position.
            cache_config: Forwarded to the ``Attention`` backend.
            quant_config: Forwarded to every linear layer and the backend.
            prefix: Module name prefix, expected to look like
                ``"<...>.<layer_idx>.<name>"``.
            reduce_results: Forwarded to ``o_proj``; ``False`` leaves the
                tensor-parallel reduction to the caller.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim

        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank

        self.num_heads = num_heads
        tp_size = get_tensor_model_parallel_world_size()
        assert num_heads % tp_size == 0
        self.num_local_heads = num_heads // tp_size

        self.scaling = self.qk_head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        # The fused q_a + kv_a GEMM replaces both q_a_proj and
        # kv_a_proj_with_mqa. It only exists on the fused-RMS-quant,
        # low-rank-q path. Decide it once so __init__ and forward always agree
        # on which projection modules exist.
        self.use_fused_qa_kva_gemm = bool(
            envs.USE_FUSED_RMS_QUANT and envs.VLLM_USE_FUSED_QA_KVA_GEMM
            and self.q_lora_rank is not None)

        if self.q_lora_rank is not None:
            if envs.USE_FUSED_RMS_QUANT:
                if self.use_fused_qa_kva_gemm:
                    # One GEMM producing [q_c | kv_c | k_pe].
                    self.qa_kva_proj = FusedQuantedReplicatedLinear(
                        self.hidden_size,
                        self.q_lora_rank,
                        self.kv_lora_rank,
                        self.qk_rope_head_dim,
                        bias=False,
                        quant_config=quant_config,
                        prefix=f"{prefix}.qa_kva_proj")
                else:
                    self.q_a_proj = ReplicatedLinear(
                        self.hidden_size,
                        self.q_lora_rank,
                        bias=False,
                        quant_config=quant_config,
                        eps=config.rms_norm_eps,
                        prefix=f"{prefix}.q_a_proj")
                self.q_b_proj = ColumnParallelLinear(
                    q_lora_rank,
                    self.num_heads * self.qk_head_dim,
                    bias=False,
                    quant_config=quant_config,
                    eps=config.rms_norm_eps,
                    prefix=f"{prefix}.q_b_proj")
            else:
                self.q_a_proj = ReplicatedLinear(
                    self.hidden_size,
                    self.q_lora_rank,
                    bias=False,
                    quant_config=quant_config,
                    prefix=f"{prefix}.q_a_proj")
                self.q_b_proj = ColumnParallelLinear(
                    q_lora_rank,
                    self.num_heads * self.qk_head_dim,
                    bias=False,
                    quant_config=quant_config,
                    prefix=f"{prefix}.q_b_proj")
            self.q_a_layernorm = RMSNorm(self.q_lora_rank,
                                         eps=config.rms_norm_eps)
        else:
            self.q_proj = ColumnParallelLinear(self.hidden_size,
                                               self.num_heads *
                                               self.qk_head_dim,
                                               bias=False,
                                               quant_config=quant_config,
                                               prefix=f"{prefix}.q_proj")

        # Needed by every path that does not use the fused qa_kva GEMM.
        if not self.use_fused_qa_kva_gemm:
            self.kv_a_proj_with_mqa = ReplicatedLinear(
                self.hidden_size,
                self.kv_lora_rank + self.qk_rope_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.kv_a_proj_with_mqa")

        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
                                      eps=config.rms_norm_eps)
        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_b_proj")
        self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
                                        self.hidden_size,
                                        bias=False,
                                        quant_config=quant_config,
                                        prefix=f"{prefix}.o_proj",
                                        reduce_results=reduce_results)

        if rope_scaling:
            # NOTE: mutates the caller's dict (the decoder passes the
            # config's own rope_scaling).
            rope_scaling["rope_type"] = 'deepseek_yarn'

        self.rotary_emb = get_rope(qk_rope_head_dim,
                                   rotary_dim=qk_rope_head_dim,
                                   max_position=max_position_embeddings,
                                   base=rope_theta,
                                   rope_scaling=rope_scaling,
                                   is_neox_style=False)
        if rope_scaling:
            mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
            scaling_factor = rope_scaling["factor"]
            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
            self.scaling = self.scaling * mscale * mscale

        # In the MLA backend, kv_cache includes both k_c and
        # pe (i.e. decoupled position embeddings). In particular,
        # the concat_and_cache_mla op requires
        #     k_c.size(1) + k_pe.size(1) == kv_cache.size(2)
        # i.e.
        #     kv_lora_rank + qk_rope_head_dim == head_size
        self.mla_attn = Attention(
            num_heads=self.num_local_heads,
            head_size=self.kv_lora_rank + self.qk_rope_head_dim,
            scale=self.scaling,
            num_kv_heads=1,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            use_mla=True,
            # MLA Args
            q_lora_rank=self.q_lora_rank,
            kv_lora_rank=self.kv_lora_rank,
            qk_nope_head_dim=self.qk_nope_head_dim,
            qk_rope_head_dim=self.qk_rope_head_dim,
            qk_head_dim=self.qk_head_dim,
            v_head_dim=self.v_head_dim,
            kv_b_proj=self.kv_b_proj,
        )

        self.prefix = prefix
        # The decoder passes "<...>.<layer_idx>.self_attn"; tolerate prefixes
        # without a layer index (e.g. the default "") instead of crashing.
        try:
            self.debug_layer_idx = int(self.prefix.split(".")[-2])
        except (IndexError, ValueError):
            self.debug_layer_idx = -1

    def _project_fused(
        self,
        hidden_states: torch.Tensor,
        rms_weight: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run q/kv_a projections with the input RMSNorm weight passed in.

        Returns ``(q, kv_c, k_pe, new_residual)``. ``q`` is not yet reshaped
        per head.

        Raises:
            NotImplementedError: if the model has no low-rank query
                (``q_lora_rank is None``); this path relies on ``q_a_proj``
                or ``qa_kva_proj`` to consume ``rms_weight``/``residual``.
        """
        if self.use_fused_qa_kva_gemm:
            qc_kvc_kpe, new_residual, _bias = self.qa_kva_proj(
                hidden_states,
                rms_weight=rms_weight,
                residual=residual,
                update_hd=False)
            q_c = qc_kvc_kpe[:, :self.q_lora_rank]
            kvc_kpe = qc_kvc_kpe[:, self.q_lora_rank:]
            q, _, _ = self.q_b_proj(q_c,
                                    rms_weight=self.q_a_layernorm.weight.data,
                                    residual=None,
                                    update_hd=False)
        else:
            if self.q_lora_rank is None:
                raise NotImplementedError(
                    "USE_FUSED_RMS_QUANT attention requires q_lora_rank; "
                    "this model has q_lora_rank=None.")
            q_c, new_residual, _, input_quant_args = self.q_a_proj(
                hidden_states,
                rms_weight=rms_weight,
                residual=residual,
                update_hd=False)
            q, _, _ = self.q_b_proj(q_c,
                                    rms_weight=self.q_a_layernorm.weight.data,
                                    residual=None,
                                    update_hd=False)
            # Reuse the quantization arguments computed by q_a_proj so the
            # same hidden_states are not quantized twice.
            kvc_kpe = self.kv_a_proj_with_mqa(hidden_states,
                                              quant_args=input_quant_args,
                                              update_hd=False)[0]
        kv_c, k_pe = kvc_kpe.split(
            [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        return q, kv_c, k_pe, new_residual

    def _project_unfused(
        self, hidden_states: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run the plain q and kv_a projections; returns ``(q, kv_c, k_pe)``."""
        if self.q_lora_rank is not None:
            q_c = self.q_a_proj(hidden_states)[0]
            q_c = self.q_a_layernorm(q_c)
            q = self.q_b_proj(q_c)[0]
        else:
            q = self.q_proj(hidden_states)[0]
        kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
            [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        return q, kv_c, k_pe

    def _mla_core(
        self,
        positions: torch.Tensor,
        q: torch.Tensor,
        kv_c: torch.Tensor,
        k_pe: torch.Tensor,
        num_tokens: int,
    ) -> torch.Tensor:
        """Normalize kv_c, apply RoPE and run the MLA backend.

        With ``VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT`` the backend receives the raw
        ``kv_c`` plus the norm weight and rotary cache, and fills the
        preallocated ``key_normed`` buffer itself. Otherwise the norm and
        RoPE are applied here first.

        Returns the attention output shaped
        ``(num_tokens, num_local_heads * v_head_dim)``.
        """
        output_shape = (num_tokens, self.num_local_heads * self.v_head_dim)
        if not envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT:
            if envs.VLLM_USE_LIGHTOP:
                kv_c_normed = self.kv_a_layernorm.forward_cuda_opt(kv_c)
            else:
                kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())

            q = q.view(-1, self.num_local_heads, self.qk_head_dim)
            # Add head dim of 1 to k_pe
            k_pe = k_pe.unsqueeze(1)

            # Rotate only the rope slice of q, writing it back in place.
            q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb(
                positions, q[..., self.qk_nope_head_dim:], k_pe)

            return self.mla_attn(q,
                                 kv_c_normed,
                                 k_pe,
                                 output_shape=output_shape)

        q = q.view(-1, self.num_local_heads, self.qk_head_dim)
        # Add head dim of 1 to k_pe
        k_pe = k_pe.unsqueeze(1)
        weight = self.kv_a_layernorm.weight
        cos_sin_cache = self.rotary_emb.cos_sin_cache
        # Move/cast the rotary cache only when it does not already match.
        if (cos_sin_cache.device != positions.device
                or cos_sin_cache.dtype != q.dtype):
            cos_sin_cache = cos_sin_cache.to(positions.device, dtype=q.dtype)
        # Output buffer for the normalized latent, written by the backend.
        kv_c_normed = torch.empty(kv_c.shape,
                                  dtype=kv_c.dtype,
                                  device=kv_c.device)
        return self.mla_attn(q[..., self.qk_nope_head_dim:],
                             kv_c,
                             k_pe,
                             output_shape=output_shape,
                             q_ori=q,
                             key_normed=kv_c_normed,
                             positions=positions,
                             weight=weight,
                             cos_sin_cache=cos_sin_cache)

    # TODO(wjl): this forward is split into fused-RMS-quant, fused
    # custom-all-reduce and default paths.
    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        rms_weight: Optional[torch.Tensor] = None,
        residual: Optional[torch.Tensor] = None,
        pa_rms_weight: Optional[torch.Tensor] = None,
        pa_residual: Optional[torch.Tensor] = None,
        pa_rms_eps: Optional[float] = 1e-6,
        pa_quant_dtype: Optional[torch.dtype] = torch.int8,
        update_input: Optional[bool] = True
    ) -> Union[torch.Tensor,
               Tuple[torch.Tensor, torch.Tensor],
               Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
        """Compute MLA attention for one layer.

        Args:
            positions: Token positions for RoPE.
            hidden_states: Layer input (pre-norm on the fused-RMS-quant
                path, already normalized otherwise).
            rms_weight: Input RMSNorm weight; enables the fused-RMS-quant
                path when ``USE_FUSED_RMS_QUANT`` is set.
            residual: Residual passed to the fused input norm (``None`` on
                the first layer).
            pa_rms_weight, pa_residual, pa_rms_eps, pa_quant_dtype,
            update_input: Post-attention norm/quant parameters forwarded to
                ``o_proj`` on the ``USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT``
                path.

        Returns:
            See the class docstring: ``(output, new_residual)``, a 4-tuple
            from ``o_proj``, or the output tensor alone.
        """
        num_tokens = hidden_states.shape[0]

        if envs.USE_FUSED_RMS_QUANT and rms_weight is not None:
            q, kv_c, k_pe, new_residual = self._project_fused(
                hidden_states, rms_weight, residual)
            attn_out = self._mla_core(positions, q, kv_c, k_pe, num_tokens)
            return self.o_proj(attn_out)[0], new_residual

        fuse_post_attn_norm = (envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT
                               and pa_rms_weight is not None
                               and pa_residual is not None)

        q, kv_c, k_pe = self._project_unfused(hidden_states)
        attn_out = self._mla_core(positions, q, kv_c, k_pe, num_tokens)

        if not fuse_post_attn_norm:
            return self.o_proj(attn_out)[0]

        packages_ = self.o_proj(attn_out,
                                pa_rms_weight=pa_rms_weight,
                                pa_residual=pa_residual,
                                pa_rms_eps=pa_rms_eps,
                                pa_quant_dtype=pa_quant_dtype,
                                update_input=update_input)[:4]
        assert len(packages_) == 4
        hs, resi, xq, xs = packages_
        assert xq is not None and xs is not None
        return hs, resi, xq, xs


class DeepseekV2DecoderLayer(nn.Module):
    """One DeepSeek-V2/V3 transformer block: attention, then an MLP or MoE.

    ``forward`` dispatches (via ``choose_forward``) to one of three
    implementations, selected by environment flags read in ``__init__``:

    * ``forward_fused_rmsquant`` (``USE_FUSED_RMS_QUANT``): the input and
      post-attention RMSNorm weights are passed into the attention and MLP
      modules instead of being applied here.
    * ``forward_fused_CRQ`` (``USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT``): the
      post-attention RMSNorm parameters are passed into the attention's
      ``o_proj``, and the MLP consumes the returned ``(xq, xs)`` pair.
    * ``forward_default``: un-fused path, including the DeepEP
      tensor-parallel all-gather / reduce-scatter handling.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        prefix: str,
        model_config: ModelConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        enable_eplb: bool = False,
    ) -> None:
        """Build attention, MLP/MoE and the two RMSNorms for one layer.

        Args:
            config: HF model config.
            prefix: Module prefix ending in the layer index (as produced by
                ``make_layers``).
            model_config: vLLM model config; ``use_mla`` selects the
                attention class.
            cache_config: Forwarded to attention.
            quant_config: Forwarded to attention and MLP/MoE.
            enable_eplb: Forwarded to ``DeepseekV2MoE``.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        # DecoderLayers are created with `make_layers` which passes the prefix
        # with the layer's index.
        layer_idx = int(prefix.split(sep='.')[-1])
        self.layer_idx = layer_idx

        self.dp_size = get_dp_group().world_size
        vllm_config = get_current_vllm_config()
        parallel_config = vllm_config.parallel_config
        self.use_deepep = (
            self.dp_size > 1 and parallel_config.enable_expert_parallel
            and envs.VLLM_ALL2ALL_BACKEND in ("deepep_high_throughput",
                                              "deepep_low_latency",
                                              "deepep_auto"))
        self.tp_size = get_tensor_model_parallel_world_size()
        self.config = config
        self.tp_rank = get_tensor_model_parallel_rank()

        if (config.n_routed_experts is not None
                and layer_idx >= config.first_k_dense_replace
                and layer_idx % config.moe_layer_freq == 0):
            self.mlp = DeepseekV2MoE(
                config=config,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
                enable_eplb=enable_eplb,
            )
        else:
            self.mlp = DeepseekV2MLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )

        # A layer indexed one past the main stack is treated as the MTP
        # layer (presumably the multi-token-prediction draft layer).
        self.is_mtp_layer = self.layer_idx == config.num_hidden_layers

        # With DeepEP + TP on a (non-MTP) MoE layer, o_proj skips its own
        # all-reduce; forward_default reduce-scatters its output instead.
        reduce_results = not (isinstance(self.mlp, DeepseekV2MoE)
                              and self.use_deepep and self.tp_size > 1
                              and not self.is_mtp_layer)

        if model_config.use_mla:
            attn_cls = DeepseekV2MLAAttention
        else:
            attn_cls = DeepseekV2Attention
        self.self_attn = attn_cls(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            qk_nope_head_dim=config.qk_nope_head_dim,
            qk_rope_head_dim=config.qk_rope_head_dim,
            v_head_dim=config.v_head_dim,
            q_lora_rank=config.q_lora_rank
            if hasattr(config, "q_lora_rank") else None,
            kv_lora_rank=config.kv_lora_rank,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
            reduce_results=reduce_results,
        )

        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)
        self.routed_scaling_factor = config.routed_scaling_factor
        self.use_fused_rms_quant = envs.USE_FUSED_RMS_QUANT
        self.use_fused_custom_all_reduce = envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT

    def forward_fused_rmsquant(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Layer forward with both RMSNorms delegated to attention / MLP.

        Attention receives the input RMSNorm weight (and residual), and the
        MLP receives the post-attention RMSNorm weight (and residual).
        Returns ``(hidden_states, residual)`` for the next layer.
        """
        # Fix residual FP16 overflow
        residual_fix_overflow = False
        assert self.input_layernorm.has_weight is True
        if residual is None:
            # First layer: there is no incoming residual, so the raw input
            # becomes the residual and the residual returned by attention
            # is discarded.
            residual = hidden_states
            hidden_states, _ = self.self_attn(
                positions=positions,
                hidden_states=hidden_states,
                rms_weight=self.input_layernorm.weight.data,
                residual=None
            )
            residual_fix_overflow = True
        else:
            hidden_states, new_residual = self.self_attn(
                positions=positions,
                hidden_states=hidden_states,
                rms_weight=self.input_layernorm.weight.data,
                residual=residual
            )
            residual = new_residual

        if hidden_states.dtype == torch.float16:
            # Fix FP16 overflow: scale before the (fused) post-attention
            # rmsnorm; the rmsnorm result is not affected by the scale.
            hidden_states *= 1. / self.routed_scaling_factor
            if self.layer_idx == 0 or residual_fix_overflow:
                # The residual is shared by all layers, we only scale it on
                # first layer.
                residual *= 1. / self.routed_scaling_factor

        hidden_states, new_resi, _i_q, _scales = self.mlp(
            hidden_states,
            rms_weight=self.post_attention_layernorm.weight.data,
            residual=residual,
        )

        if isinstance(self.mlp,
                      DeepseekV2MLP) and hidden_states.dtype == torch.float16:
            # Fix FP16 overflow
            # Scaling the DeepseekV2MLP output, it is the input of
            # input_layernorm of next decoder layer.
            # The scaling of DeepseekV2MOE output would be done in the forward
            # of DeepseekV2MOE
            hidden_states *= 1. / self.routed_scaling_factor
        return hidden_states, new_resi

    def forward_fused_CRQ(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Layer forward with the post-attention RMSNorm fused into o_proj.

        The input RMSNorm runs here. Attention returns ``(hidden, residual,
        xq, xs)``, and the MLP is fed ``(xq, xs)``. Returns
        ``(hidden_states, residual)`` for the next layer.
        """
        residual_fix_overflow = False
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
            residual_fix_overflow = True
        else:
            hidden_states, resi_new = self.input_layernorm(
                hidden_states, residual)
            residual = resi_new

        new_hs, new_resi, xq, xs = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            pa_rms_weight=self.post_attention_layernorm.weight.data,
            pa_residual=residual,
            pa_rms_eps=self.post_attention_layernorm.variance_epsilon,
            pa_quant_dtype=torch.int8,
            update_input=True
        )

        assert xq is not None and xs is not None
        # FP16 overflow handling.
        # NOTE(review): unlike forward_default, the residual add presumably
        # already happened inside o_proj here. On layers other than the
        # first, the rescale below therefore touches new_hs but not the
        # attention contribution already folded into new_resi (nor xq/xs).
        # Confirm this matches forward_default numerically in fp16.
        if new_hs.dtype == torch.float16:
            new_hs *= 1. / self.routed_scaling_factor
            if self.layer_idx == 0 or residual_fix_overflow:
                new_resi *= 1. / self.routed_scaling_factor

        hidden_states = self.mlp(new_hs, xqxs=(xq, xs))

        if isinstance(self.mlp,
                      DeepseekV2MLP) and hidden_states.dtype == torch.float16:
            hidden_states *= 1. / self.routed_scaling_factor
        return hidden_states, new_resi

    def forward_default(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Un-fused layer forward.

        With DeepEP + TP on MoE layers, the attention output (left unreduced
        by o_proj) is reduce-scattered along dim 0, so each rank carries only
        its slice of tokens through the MoE. The slice is all-gathered again
        before the next MoE layer's attention. The MTP layer instead pads the
        tokens to a multiple of ``tp_size``, runs its slice through the MoE,
        all-gathers the result and drops the padding.

        Returns ``(hidden_states, residual)`` for the next layer.
        """
        # Self Attention
        # Fix residual FP16 overflow
        residual_fix_overflow = False
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
            residual_fix_overflow = True
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)

        deepep_tp_moe = (isinstance(self.mlp, DeepseekV2MoE)
                         and self.use_deepep and self.tp_size > 1)

        # Tokens arrive scattered from the previous MoE layer; gather them
        # back before attention. The first MoE layer still has full tokens.
        if (not self.is_mtp_layer and deepep_tp_moe
                and self.layer_idx > self.config.first_k_dense_replace):
            hidden_states = tensor_model_parallel_all_gather(hidden_states,
                                                             dim=0)

        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )

        if not self.is_mtp_layer and deepep_tp_moe:
            if self.layer_idx == self.config.first_k_dense_replace:
                # First MoE layer: the residual is still full-length; keep
                # this rank's token slice to match the reduce-scatter below.
                residual = residual.tensor_split(self.tp_size)[self.tp_rank]

            hidden_states = tensor_model_parallel_reduce_scatter(hidden_states,
                                                                 dim=0)

        if hidden_states.dtype == torch.float16:
            # Fix FP16 overflow
            # We scale both hidden_states and residual before
            # rmsnorm, and rmsnorm result would not affect by scale.
            hidden_states *= 1. / self.routed_scaling_factor
            if self.layer_idx == 0 or residual_fix_overflow:
                # The residual is shared by all layers, we only scale it on
                # first layer.
                residual *= 1. / self.routed_scaling_factor

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)

        if self.is_mtp_layer and deepep_tp_moe:
            # Pad the token count up to a multiple of tp_size and keep this
            # rank's contiguous slice.
            ori_bs = hidden_states.shape[0]
            pad_size = (ori_bs + self.tp_size -
                        1) // self.tp_size * self.tp_size - ori_bs
            if pad_size > 0:
                hidden_states = torch.nn.functional.pad(
                    hidden_states.contiguous(), [0, 0, 0, pad_size],
                    value=0).contiguous()
            new_bs = (ori_bs + pad_size) // self.tp_size
            hidden_states = hidden_states[self.tp_rank * new_bs:(self.tp_rank + 1) *
                                          new_bs, :].contiguous()

        hidden_states = self.mlp(hidden_states)

        if self.is_mtp_layer and deepep_tp_moe:
            hidden_states = tensor_model_parallel_all_gather(hidden_states,
                                                             dim=0)
            # Drop the padding added above.
            hidden_states = hidden_states[:ori_bs, :]

        if isinstance(self.mlp,
                      DeepseekV2MLP) and hidden_states.dtype == torch.float16:
            # Fix FP16 overflow
            # Scaling the DeepseekV2MLP output, it is the input of
            # input_layernorm of next decoder layer.
            # The scaling of DeepseekV2MOE output would be done in the forward
            # of DeepseekV2MOE
            hidden_states *= 1. / self.routed_scaling_factor

        return hidden_states, residual

    def choose_forward(self):
        """Return the forward implementation selected by the env flags."""
        if self.use_fused_rms_quant:
            return self.forward_fused_rmsquant
        elif self.use_fused_custom_all_reduce:
            return self.forward_fused_CRQ
        else:
            return self.forward_default

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the layer; returns ``(hidden_states, residual)``."""
        forward_func = self.choose_forward()
        return forward_func(positions=positions,
                            hidden_states=hidden_states,
                            residual=residual)


@support_torch_compile
class DeepseekV2Model(nn.Module):
    """DeepseekV2/V3 transformer backbone: embeddings, decoder stack, norm.

    Pipeline parallelism: only the first PP rank owns ``embed_tokens`` and
    only the last PP rank owns the final ``norm``; the other ranks hold
    ``PPMissingLayer`` placeholders and exchange ``hidden_states`` /
    ``residual`` through ``IntermediateTensors``.
    """

    fall_back_to_pt_during_load = False

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        # Read parallel settings from the config handed to this constructor
        # (not the global "current" config) so every setting below comes from
        # the same source, even when the model is built outside a
        # set_current_vllm_config() context.
        parallel_config = vllm_config.parallel_config
        enable_eplb = parallel_config.enable_eplb
        self.config = config

        self.vocab_size = config.vocab_size

        if get_pp_group().is_first_rank:
            self.embed_tokens = VocabParallelEmbedding(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=f"{prefix}.embed_tokens")
        else:
            self.embed_tokens = PPMissingLayer()

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: DeepseekV2DecoderLayer(
                config,
                prefix,
                model_config=model_config,
                cache_config=cache_config,
                quant_config=quant_config,
                enable_eplb=enable_eplb,
            ),
            prefix=f"{prefix}.layers")

        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

        # DeepEP is active only with data parallelism + expert parallelism
        # and one of the DeepEP all-to-all backends selected.
        self.dp_size = get_dp_group().world_size
        self.use_deepep = (
            self.dp_size > 1 and parallel_config.enable_expert_parallel
            and envs.VLLM_ALL2ALL_BACKEND in (
                "deepep_high_throughput",
                "deepep_low_latency",
                "deepep_auto",
            ))
        self.tp_size = get_tensor_model_parallel_world_size()

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed ``input_ids`` with the (first-PP-rank) token embedding."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this PP rank's slice of the decoder stack.

        Args:
            input_ids: Token ids; used only on the first PP rank when
                ``inputs_embeds`` is not given.
            positions: Token positions passed to every decoder layer.
            intermediate_tensors: ``hidden_states``/``residual`` from the
                previous PP rank; required on every rank but the first.
            inputs_embeds: Optional precomputed embeddings (first rank only).

        Returns:
            ``IntermediateTensors`` on non-last PP ranks, otherwise the
            normalized final hidden states.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for layer in self.layers[self.start_layer:self.end_layer]:
            hidden_states, residual = layer(positions, hidden_states, residual)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })

        hidden_states, _ = self.norm(hidden_states, residual)

        # With DeepEP + TP the decoder layers presumably leave each TP rank
        # holding only its slice of the tokens; gather them back along dim 0.
        # NOTE(review): the decoder layer pads the batch to a multiple of
        # tp_size before slicing, and its MTP path trims back to ori_bs after
        # gathering; confirm no padding rows need trimming here.
        if self.use_deepep and self.tp_size > 1:
            hidden_states = tensor_model_parallel_all_gather(hidden_states,
                                                             dim=0)

        return hidden_states


1324
class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
    """DeepseekV2/V3 causal LM: backbone, LM head and MoE/EPLB bookkeeping.

    Wraps :class:`DeepseekV2Model`, adds the LM head and logits processor,
    and exposes the per-layer ``FusedMoE`` modules and expert counts used by
    the ``MixtureOfExperts`` interface.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config

        self.quant_method = None
        if quant_config is not None:
            self.quant_method = quant_config.get_name()
            # Disable the LLAMA_NN / LM_NN weight-layout paths for quantized
            # models. NOTE: this mutates the process environment, so it also
            # affects anything constructed later in this process (including
            # the backbone built below and the use_llama_nn flag).
            os.environ['LLAMA_NN'] = '0'
            os.environ['LM_NN'] = '0'

        self.use_w4a16_moe_sz = os.environ.get('AWQ_MOE_SZ') == '1'

        self.config = config
        self.quant_config = quant_config

        self.model = DeepseekV2Model(vllm_config=vllm_config,
                                     prefix=maybe_prefix(prefix, "model"))
        if get_pp_group().is_last_rank:
            self.lm_head = ParallelLMHead(config.vocab_size,
                                          config.hidden_size,
                                          quant_config=quant_config)
        else:
            self.lm_head = PPMissingLayer()
        self.logits_processor = LogitsProcessor(config.vocab_size)
        # Instance attribute; shadows the make_empty_intermediate_tensors
        # method defined on this class.
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)
        self.expert_weights = []

        # Set MoE hyperparameters
        self.num_moe_layers = (config.num_hidden_layers -
                               config.first_k_dense_replace)
        self.num_expert_groups = config.n_group

        self.moe_layers: list[FusedMoE] = []
        example_moe = None
        for layer in self.model.layers:
            if isinstance(layer, PPMissingLayer):
                continue

            assert isinstance(layer, DeepseekV2DecoderLayer)
            if isinstance(layer.mlp, DeepseekV2MoE):
                # Keep the last MoE layer seen; the first ones may be dense.
                example_moe = layer.mlp
                self.moe_layers.append(layer.mlp.experts)

        if example_moe is None:
            raise RuntimeError(
                "No DeepseekV2MoE layer found in model.layers on this "
                "pipeline-parallel rank; cannot determine MoE expert counts.")
        self.num_logical_experts = example_moe.n_logical_experts
        self.num_physical_experts = example_moe.n_physical_experts
        self.num_local_physical_experts = example_moe.n_local_physical_experts
        self.num_routed_experts = example_moe.n_routed_experts
        self.num_shared_experts = example_moe.n_shared_experts
        self.num_redundant_experts = example_moe.n_redundant_experts

        # Read after the os.environ update above, so quantized models always
        # see LLAMA_NN == '0' here.
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
        self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
        # Publish top-k and the quant method on the W8a8GetCacheJSON object
        # (presumably consumed when selecting Triton kernel configs — confirm).
        self.tritonsingleton = W8a8GetCacheJSON()
        self.tritonsingleton.topk = config.num_experts_per_tok
        self.tritonsingleton.quant_method = self.quant_method

    def set_eplb_state(
        self,
        expert_load_view: torch.Tensor,
        logical_to_physical_map: torch.Tensor,
        logical_replica_count: torch.Tensor,
    ) -> None:
        """Register expert weights and hand EPLB state to every MoE layer.

        NOTE: ``self.expert_weights`` is appended to on every call.
        """
        for layer_idx, layer in enumerate(self.moe_layers):
            # Register the expert weights.
            self.expert_weights.append(layer.get_expert_weights())
            layer.set_eplb_state(
                moe_layer_idx=layer_idx,
                expert_load_view=expert_load_view,
                logical_to_physical_map=logical_to_physical_map,
                logical_replica_count=logical_replica_count,
            )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate token embedding to the backbone."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the backbone; logits are computed separately."""
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits via the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def make_empty_intermediate_tensors(
            self, batch_size: int, dtype: torch.dtype,
            device: torch.device) -> IntermediateTensors:
        """Allocate zeroed ``hidden_states``/``residual`` PP buffers.

        NOTE: ``__init__`` assigns an instance attribute with this same name,
        which shadows this method on constructed instances.
        """
        return IntermediateTensors({
            "hidden_states":
            torch.zeros((batch_size, self.config.hidden_size),
                        dtype=dtype,
                        device=device),
            "residual":
            torch.zeros((batch_size, self.config.hidden_size),
                        dtype=dtype,
                        device=device),
        })

    def restore_qzeros_tensor(self, qzeros, qscales):
        """Combine nibble-packed zero points with scales into uint32 words.

        Each ``qzeros`` element is split into its low and high 4-bit nibble.
        For a 3-D ``qzeros`` of shape (A, B, C) the nibbles are interleaved
        along dim 1, giving (A, 2B, C). They are widened to 16 bits, and each
        output word is ``zero << 16 | scale_bits``, where ``scale_bits`` is
        the raw 16-bit pattern of ``qscales``. Dims 1 and 2 are then swapped.

        Args:
            qzeros: Integer tensor of packed zero-point nibbles (assumed 3-D —
                TODO confirm).
            qscales: Tensor with 2-byte elements whose shape equals the
                unpacked zero-point shape (asserted).

        Returns:
            A contiguous ``torch.uint32`` tensor.
        """
        low_bits = qzeros & 0x0F
        high_bits = qzeros >> 4

        zeros_tensor = torch.stack([low_bits, high_bits], dim=2).view(
            qzeros.shape[0], -1, qzeros.shape[-1])
        zeros_int16 = zeros_tensor.to(torch.int16)
        assert zeros_int16.shape == qscales.shape

        zeros_u16 = zeros_int16.view(torch.uint16)
        scales_u16 = qscales.view(torch.uint16)

        # Zero point in the high half-word, scale bits in the low half-word.
        high_word = zeros_u16.to(torch.int32) << 16
        low_word = scales_u16.to(torch.int32)

        result_tensor = high_word + low_word
        result_tensor = result_tensor.view(torch.uint32)
        result_tensor = result_tensor.transpose(1, 2).contiguous()
        return result_tensor

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this model's parameters.

        Handles stacked projections (gate/up, and optionally fused
        q_a/kv_a), per-expert MoE weights, kv-scale renaming, pipeline
        parallel placeholders and speculative-decoding layers (skipped).
        For unquantized models with ``LLAMA_NN`` enabled, selected weights
        are rewritten in place with ``ops.trans_w16_gemm`` after loading.

        Returns:
            Names of the parameters that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        use_fused_qa_kva = (envs.USE_FUSED_RMS_QUANT
                            and envs.VLLM_USE_FUSED_QA_KVA_GEMM)
        if use_fused_qa_kva:
            # q_a_proj and kv_a_proj_with_mqa share one fused parameter.
            stacked_params_mapping += [
                ("qa_kva_proj", "q_a_proj", 0),
                ("qa_kva_proj", "kv_a_proj_with_mqa", 1),
            ]

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.n_routed_experts,
            num_redundant_experts=self.num_redundant_experts)

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
            if spec_layer is not None:
                continue  # skip spec decode layers for main model

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if (("mlp.experts." in name) and name not in params_dict):
                    continue
                old_weight_name = name
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                if use_fused_qa_kva and ("q_a_proj" in old_weight_name
                                         or "kv_a_proj_with_mqa"
                                         in old_weight_name):
                    # The fused qa_kva_proj loader receives the original
                    # checkpoint name instead of a numeric shard id.
                    weight_loader(param, loaded_weight, old_weight_name)
                else:
                    weight_loader(param, loaded_weight, shard_id)
                break
            else:
                is_expert_weight = False
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue

                    # Anyway, this is an expert weight and should not be
                    # attempted to load as other weights later
                    is_expert_weight = True

                    # Do not modify `name` since the loop may continue here
                    # Instead, create a new variable
                    name_mapped = name.replace(weight_name, param_name)

                    if is_pp_missing_parameter(name_mapped, self):
                        continue

                    param = params_dict[name_mapped]
                    # We should ask the weight loader to return success or not
                    # here since otherwise we may skip experts with other
                    # available replicas.
                    weight_loader = typing.cast(Callable[..., bool],
                                                param.weight_loader)
                    success = weight_loader(param,
                                            loaded_weight,
                                            name_mapped,
                                            shard_id=shard_id,
                                            expert_id=expert_id,
                                            return_success=True)
                    if success:
                        name = name_mapped
                        break
                else:
                    if is_expert_weight:
                        # We've checked that this is an expert weight
                        # However it's not mapped locally to this rank
                        # So we simply skip it
                        continue

                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue

                    # Remapping the name of FP8 kv-scale.
                    name = maybe_remap_kv_scale_name(name, params_dict)
                    if name is None:
                        continue

                    if is_pp_missing_parameter(name, self):
                        continue

                    try:
                        param = params_dict[name]
                    except KeyError:
                        # Best-effort: checkpoint tensors with no matching
                        # parameter in this model are skipped.
                        continue
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)

        if self.use_llama_nn and self.quant_method is None:
            lay_key_words = [
                "self_attn.q_proj.weight",
                "self_attn.q_a_proj.weight",
                "self_attn.q_b_proj.weight",
                "self_attn.kv_a_proj_with_mqa.weight",
                "self_attn.kv_b_proj.weight",
                "self_attn.o_proj.weight",
                "mlp.gate_up_proj.weight",
                "mlp.down_proj.weight",
                "mlp.gate.weight",
                "shared_experts.gate_up_proj.weight",
                "shared_experts.down_proj.weight",
                "lm_head.weight",
            ]
            # Escape the keys so "." matches a literal dot, and compile the
            # pattern once instead of re-parsing it for every parameter.
            key_pattern = re.compile("|".join(
                re.escape(word) for word in lay_key_words))

            for layername in loaded_params:
                weight = params_dict[layername]
                if key_pattern.search(layername):
                    _weight = torch.zeros_like(weight.data)
                    ori_shape = _weight.shape
                    ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0],
                                       _weight.shape[1])
                    weight.data.copy_(_weight)
                    weight.data = weight.data.reshape(ori_shape[1], -1)

        return loaded_params
1612
1613
1614
1615


class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
    """DeepseekV3 causal LM; reuses the DeepseekV2 implementation unchanged."""
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627


def get_spec_layer_idx_from_weight_name(config: PretrainedConfig,
                                        weight_name: str) -> Optional[int]:
    """Return the speculative (next-n prediction) layer index of a weight.

    Speculative layers are numbered directly after the regular decoder
    layers, i.e. ``num_hidden_layers .. num_hidden_layers +
    num_nextn_predict_layers - 1``.

    Returns:
        The layer index if ``weight_name`` starts with
        ``model.layers.<idx>.`` for one of those indices, else ``None``.
    """
    if not hasattr(config, "num_nextn_predict_layers"):
        return None
    num_spec_layers = config.num_nextn_predict_layers
    if not num_spec_layers > 0:
        return None

    first_spec_idx = config.num_hidden_layers
    for idx in range(first_spec_idx, first_spec_idx + num_spec_layers):
        if weight_name.startswith(f"model.layers.{idx}."):
            return idx
    return None