deepseek_v2.py 67.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

wangding zeng's avatar
wangding zeng committed
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
25
"""Inference-only DeepseekV2/DeepseekV3 model."""
王敏's avatar
王敏 committed
26
27
import os
import re
28
import vllm.envs as envs
zhuwenwen's avatar
zhuwenwen committed
29

30
31
import typing
from collections.abc import Callable, Iterable
32
from itertools import islice
33
from typing import Any, Optional, Union
wangding zeng's avatar
wangding zeng committed
34
35
36

import torch
from torch import nn
37
from transformers import DeepseekV2Config, DeepseekV3Config
wangding zeng's avatar
wangding zeng committed
38

39
from vllm.attention import Attention
40
41
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton
42
from vllm.compilation.decorators import support_torch_compile
43
44
from vllm.config import (CacheConfig, ParallelConfig, VllmConfig,
                         get_current_vllm_config)
45
from vllm.distributed import (get_ep_group, get_pp_group,
46
47
48
                              get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              tensor_model_parallel_all_gather)
49
50
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
wangding zeng's avatar
wangding zeng committed
51
from vllm.model_executor.layers.activation import SiluAndMul
52
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
53
from vllm.model_executor.layers.fused_moe import FusedMoE
54
from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm
wangding zeng's avatar
wangding zeng committed
55
56
57
58
59
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               MergedColumnParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
60
from vllm.model_executor.layers.mla import MLAModules, MultiHeadLatentAttention
61
from vllm.model_executor.layers.quantization import QuantizationConfig
62
63
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    per_token_group_quant_fp8)
wangding zeng's avatar
wangding zeng committed
64
from vllm.model_executor.layers.rotary_embedding import get_rope
65
from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE
wangding zeng's avatar
wangding zeng committed
66
67
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
68
69
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, maybe_remap_kv_scale_name)
70
from vllm.model_executor.models.utils import sequence_parallel_chunk
71
from vllm.platforms import current_platform
72
from vllm.sequence import IntermediateTensors
73
from vllm.utils import cdiv, direct_register_custom_op
74
75
76
from vllm.v1.attention.backends.mla.indexer import (DeepseekV32IndexerBackend,
                                                    DeepseekV32IndexerMetadata)
from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec
wangding zeng's avatar
wangding zeng committed
77

78
from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
79
from .utils import (PPMissingLayer, is_pp_missing_parameter,
80
81
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
王敏's avatar
王敏 committed
82
from vllm import _custom_ops as ops
83
from vllm.utils import W8a8GetCacheJSON
84

85
if current_platform.is_rocm():
86
    import lightop
87
88
89
90
    from lightop import op, gemmopt
else:
    from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits

91

92
93
94
95
96
97
98
if current_platform.is_cuda_alike():
    from vllm import _custom_ops as ops
elif current_platform.is_xpu():
    from vllm._ipex_ops import ipex_ops as ops

logger = init_logger(__name__)

wangding zeng's avatar
wangding zeng committed
99
100
101
102
103
104
105
106
107
108

class DeepseekV2MLP(nn.Module):
    """Gated MLP: ``down_proj(SiluAndMul(gate_up_proj(x)))``.

    Only the ``"silu"`` activation is supported. The gate and up projections
    are merged into a single ``MergedColumnParallelLinear``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
        is_sequence_parallel: bool = False,
        prefix: str = "",
    ) -> None:
        """Build the merged gate/up projection and the down projection.

        Args:
            hidden_size: Input and output feature size.
            intermediate_size: Size of each of the gate and up projections.
            hidden_act: Activation name; must be ``"silu"``.
            quant_config: Quantization config forwarded to both linears.
            reduce_results: Forwarded to ``RowParallelLinear``.
            is_sequence_parallel: Forwarded as ``disable_tp`` to both linears.
            prefix: Parameter-name prefix for the sub-layers.

        Raises:
            ValueError: If ``hidden_act`` is not ``"silu"``.
        """
        super().__init__()

        # If is_sequence_parallel, the input and output tensors are sharded
        # across the ranks within the tp_group. In this case the weights are
        # replicated and no collective ops are needed.
        # Otherwise we use standard TP with an allreduce at the end.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            disable_tp=is_sequence_parallel,
            prefix=f"{prefix}.gate_up_proj")
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           quant_config=quant_config,
                                           reduce_results=reduce_results,
                                           disable_tp=is_sequence_parallel,
                                           prefix=f"{prefix}.down_proj")
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x,
                rms_weight: Optional[torch.Tensor] = None,
                residual: Optional[torch.Tensor] = None,
                update_hd: Optional[bool] = False
                ):
        """Apply the MLP.

        The return shape depends on ``envs.USE_FUSED_RMS_QUANT``:

        * enabled: ``rms_weight``, ``residual`` and ``update_hd`` are passed
          to ``gate_up_proj`` and the result is ``(output, new_residual)``;
        * disabled: those arguments are ignored and only ``output`` is
          returned.

        NOTE(review): the extra ``gate_up_proj`` / ``down_proj`` arguments
        presumably fuse RMSNorm and activation with quantization inside the
        linear layers -- confirm against the linear layer implementation.
        """
        if envs.USE_FUSED_RMS_QUANT:
            gate_up, new_resi, _ = self.gate_up_proj(x,
                                                     rms_weight,
                                                     residual,
                                                     update_hd=update_hd)
            if envs.USE_FUSED_SILU_MUL_QUANT:
                # The activation is delegated to down_proj in this mode.
                x, _ = self.down_proj(gate_up, use_fused_silu_mul_quant=True)
            else:
                x = self.act_fn(gate_up)
                x, _ = self.down_proj(x)
            return x, new_resi
        else:
            gate_up, _ = self.gate_up_proj(x)
            x = self.act_fn(gate_up)
            x, _ = self.down_proj(x)
            return x
wangding zeng's avatar
wangding zeng committed
154
155
156
157
158
159


class DeepseekV2MoE(nn.Module):
    """Mixture-of-experts block with an optional set of shared experts.

    A replicated linear ``gate`` produces router logits. Routed experts are
    run through ``FusedMoE``, or through ``SharedFusedMoE`` when
    ``config.n_shared_experts`` is set, in which case the shared-expert MLP
    output is added to the routed output.
    """

    def __init__(
        self,
        config: Union[DeepseekV2Config, DeepseekV3Config],
        parallel_config: ParallelConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the router gate and the expert layers.

        Args:
            config: Model config (expert counts, top-k settings, sizes).
            parallel_config: Supplies the EPLB settings and the
                ``use_sequence_parallel_moe`` flag.
            quant_config: Quantization config for the expert layers. The
                gate itself is always built with ``quant_config=None``.
            prefix: Parameter-name prefix for the sub-layers.

        Raises:
            ValueError: If ``config.hidden_act`` is not ``"silu"``.
        """
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()

        self.routed_scaling_factor = config.routed_scaling_factor

        self.ep_group = get_ep_group().device_group
        self.ep_rank = self.ep_group.rank()
        self.ep_size = self.ep_group.size()
        self.n_routed_experts: int = config.n_routed_experts
        self.n_shared_experts: int = config.n_shared_experts

        self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe

        if config.hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {config.hidden_act}. "
                             "Only silu is supported for now.")

        self.gate = ReplicatedLinear(config.hidden_size,
                                     config.n_routed_experts,
                                     bias=False,
                                     quant_config=None,
                                     prefix=f"{prefix}.gate")
        if config.topk_method == "noaux_tc":
            if envs.VLLM_ENABLE_MOE_FUSED_GATE:
                # avoid moe_fused_gate precision error
                # (the bias is created with the default dtype here instead
                # of being forced to float32)
                self.gate.e_score_correction_bias = nn.Parameter(
                    torch.empty(config.n_routed_experts))
            else:
                self.gate.e_score_correction_bias = nn.Parameter(
                    torch.empty(config.n_routed_experts, dtype=torch.float32))
        else:
            self.gate.e_score_correction_bias = None

        # Load balancing settings.
        eplb_config = parallel_config.eplb_config
        self.enable_eplb = parallel_config.enable_eplb

        self.n_redundant_experts = eplb_config.num_redundant_experts
        self.n_logical_experts = self.n_routed_experts
        self.n_physical_experts = (self.n_logical_experts +
                                   self.n_redundant_experts)
        self.n_local_physical_experts = self.n_physical_experts // self.ep_size

        # Half-open range [start, end) of physical experts on this EP rank.
        self.physical_expert_start = (self.ep_rank *
                                      self.n_local_physical_experts)
        self.physical_expert_end = (self.physical_expert_start +
                                    self.n_local_physical_experts)

        # Keyword arguments common to FusedMoE and SharedFusedMoE.
        experts_kwargs = dict(
            num_experts=config.n_routed_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            reduce_results=False,
            renormalize=config.norm_topk_prob,
            quant_config=quant_config,
            use_grouped_topk=True,
            num_expert_group=config.n_group,
            topk_group=config.topk_group,
            prefix=f"{prefix}.experts",
            scoring_func=config.scoring_func,
            # we do scaling outside, set factor to 1.0 to avoid double mul
            routed_scaling_factor=1.0,
            e_score_correction_bias=self.gate.e_score_correction_bias,
            enable_eplb=self.enable_eplb,
            num_redundant_experts=self.n_redundant_experts,
            is_sequence_parallel=self.is_sequence_parallel,
        )

        if config.n_shared_experts is None:
            self.experts = FusedMoE(**experts_kwargs)
            self.shared_experts = None
        else:
            intermediate_size = (config.moe_intermediate_size *
                                 config.n_shared_experts)

            self.shared_experts = DeepseekV2MLP(
                hidden_size=config.hidden_size,
                intermediate_size=intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                is_sequence_parallel=self.is_sequence_parallel,
                reduce_results=False,
                prefix=f"{prefix}.shared_experts",
            )
            self.experts = SharedFusedMoE(
                shared_experts=self.shared_experts,
                **experts_kwargs,
            )

    def forward(self, hidden_states: torch.Tensor,
                rms_weight: Optional[torch.Tensor] = None,
                residual: Optional[torch.Tensor] = None
                ) -> torch.Tensor:
        """Route tokens through the experts and combine the results.

        Args:
            hidden_states: 2-D tensor ``(num_tokens, hidden_dim)``.
            rms_weight: Accepted but not used by this method.
            residual: Accepted but not used by this method.

        Returns:
            Tensor of shape ``(num_tokens, hidden_dim)``.
        """
        num_tokens, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)

        # Chunk the hidden states so they aren't replicated across TP ranks.
        # This avoids duplicate computation in self.experts.
        # TODO: We can replace the all_reduce at the end of attn with a
        # reduce_scatter instead of chunking here.
        if self.is_sequence_parallel:
            hidden_states = sequence_parallel_chunk(hidden_states)

        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)

        fused_moe_out = self.experts(hidden_states=hidden_states,
                                     router_logits=router_logits)
        if self.shared_experts is not None:
            # With shared experts the layer returns a (shared, routed) pair.
            shared_output, final_hidden_states = fused_moe_out
        else:
            shared_output = None
            final_hidden_states = fused_moe_out

        # Fix FP16 overflow
        # See DeepseekV2DecoderLayer for more details.
        # For non-fp16 inputs the routed output is scaled up; for fp16 the
        # shared output is scaled down by the same factor instead.
        if hidden_states.dtype != torch.float16:
            final_hidden_states *= self.routed_scaling_factor
        elif self.shared_experts is not None:
            assert shared_output is not None
            shared_output *= (1. / self.routed_scaling_factor)

        if self.shared_experts is not None:
            assert shared_output is not None
            final_hidden_states += shared_output

        if self.is_sequence_parallel:
            # Undo the earlier chunking: gather along dim 0, then trim back
            # to the original token count.
            final_hidden_states = tensor_model_parallel_all_gather(
                final_hidden_states, 0)
            final_hidden_states = final_hidden_states[:num_tokens]
        elif self.tp_size > 1:
            final_hidden_states = (
                self.experts.maybe_all_reduce_tensor_model_parallel(
                    final_hidden_states))

        return final_hidden_states.view(num_tokens, hidden_dim)


def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    """Return the YaRN attention magnitude-scaling factor.

    Yields ``1.0`` when ``scale <= 1``; otherwise
    ``0.1 * mscale * ln(scale) + 1.0``.
    """
    from math import log

    return 1.0 if scale <= 1 else 0.1 * mscale * log(scale) + 1.0


class DeepseekV2Attention(nn.Module):
    """DeepSeek-V2/V3 attention using a regular ``Attention`` layer.

    Queries are produced either through a low-rank pair
    (``q_a_proj`` -> ``q_a_layernorm`` -> ``q_b_proj``) when ``q_lora_rank``
    is not ``None``, or through a single ``q_proj``. Keys and values are
    expanded from a normalized low-rank latent via ``kv_b_proj``; the rotary
    part of the key comes straight from ``kv_a_proj_with_mqa``.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        config: Union[DeepseekV2Config, DeepseekV3Config],
        hidden_size: int,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: Optional[int],
        kv_lora_rank: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        topk_indices_buffer: Optional[torch.Tensor] = None,
        prefix: str = "",
    ) -> None:
        """Build the projections, rotary embedding and attention layer.

        Args:
            vllm_config: Accepted for interface parity; not read here.
            config: Model config; only ``rms_norm_eps`` is read here.
            hidden_size: Model hidden size.
            num_heads: Total attention heads; must be divisible by the
                tensor-parallel world size.
            qk_nope_head_dim: Per-head size of the non-rotary q/k part.
            qk_rope_head_dim: Per-head size of the rotary q/k part.
            v_head_dim: Per-head value size.
            q_lora_rank: Rank of the low-rank query projection, or ``None``
                to use a single full-rank ``q_proj``.
            kv_lora_rank: Rank of the compressed KV latent.
            rope_theta: RoPE base.
            rope_scaling: Optional RoPE scaling dict. NOTE: when truthy, its
                ``"rope_type"`` entry is overwritten in place with
                ``'deepseek_yarn'``.
            max_position_embeddings: Maximum position for the rotary table.
            cache_config: Forwarded to ``Attention``.
            quant_config: Forwarded to the linear layers and ``Attention``.
            topk_indices_buffer: Must be ``None`` for this class.
            prefix: Parameter-name prefix for the sub-layers.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.num_heads = num_heads
        tp_size = get_tensor_model_parallel_world_size()
        assert num_heads % tp_size == 0
        self.num_local_heads = num_heads // tp_size
        self.scaling = self.qk_head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        # Implicit string concatenation keeps the message free of the
        # indentation whitespace a backslash continuation would embed.
        assert topk_indices_buffer is None, (
            "topk_indices_buffer is not "
            "supported for DeepseekV2Attention")

        if self.q_lora_rank is not None:
            self.q_a_proj = ReplicatedLinear(self.hidden_size,
                                             self.q_lora_rank,
                                             bias=False,
                                             quant_config=quant_config,
                                             prefix=f"{prefix}.q_a_proj")
            self.q_a_layernorm = RMSNorm(self.q_lora_rank,
                                         eps=config.rms_norm_eps)
            self.q_b_proj = ColumnParallelLinear(q_lora_rank,
                                                 self.num_heads *
                                                 self.qk_head_dim,
                                                 bias=False,
                                                 quant_config=quant_config,
                                                 prefix=f"{prefix}.q_b_proj")
        else:
            self.q_proj = ColumnParallelLinear(self.hidden_size,
                                               self.num_heads *
                                               self.qk_head_dim,
                                               bias=False,
                                               quant_config=quant_config,
                                               prefix=f"{prefix}.q_proj")

        self.kv_a_proj_with_mqa = ReplicatedLinear(
            self.hidden_size,
            self.kv_lora_rank + self.qk_rope_head_dim,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_a_proj_with_mqa")
        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
                                      eps=config.rms_norm_eps)
        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_b_proj")
        # O projection.
        self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
                                        self.hidden_size,
                                        bias=False,
                                        quant_config=quant_config,
                                        prefix=f"{prefix}.o_proj")
        if rope_scaling:
            # In-place write: the caller's dict is modified.
            rope_scaling["rope_type"] = 'deepseek_yarn'

        self.rotary_emb = get_rope(qk_rope_head_dim,
                                   rotary_dim=qk_rope_head_dim,
                                   max_position=max_position_embeddings,
                                   base=rope_theta,
                                   rope_scaling=rope_scaling,
                                   is_neox_style=False)

        if rope_scaling:
            # A missing "mscale_all_dim" becomes float(False) == 0.0, which
            # makes yarn_get_mscale return exactly 1.0 for factor > 1.
            mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
            scaling_factor = rope_scaling["factor"]
            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
            self.scaling = self.scaling * mscale * mscale

        self.attn = Attention(self.num_local_heads,
                              self.qk_head_dim,
                              self.scaling,
                              num_kv_heads=self.num_local_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states`` and apply ``o_proj``.

        Args:
            positions: Token positions passed to the rotary embedding.
            hidden_states: Input activations.

        Returns:
            The output of ``o_proj``.
        """
        if self.q_lora_rank is not None:
            q = self.q_a_proj(hidden_states)[0]
            q = self.q_a_layernorm(q)
            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads,
                                         self.qk_head_dim)
        else:
            q = self.q_proj(hidden_states)[0].view(-1, self.num_local_heads,
                                                   self.qk_head_dim)
        # Only the rotary slice is needed separately; the non-rotary slice
        # stays in place inside q.
        _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim],
                          dim=-1)
        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
        kv_a, _ = latent_cache.split(
            [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        latent_cache = latent_cache.unsqueeze(1)
        kv_a = self.kv_a_layernorm(kv_a)
        kv = self.kv_b_proj(kv_a)[0]
        kv = kv.view(-1, self.num_local_heads,
                     self.qk_nope_head_dim + self.v_head_dim)
        k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
        # Rotary key part: a single head dim (from unsqueeze(1)) that is
        # broadcast to all local heads when written into k below.
        k_pe = latent_cache[:, :, self.kv_lora_rank:]

        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)

        q[..., self.qk_nope_head_dim:] = q_pe
        k = torch.empty_like(q)
        k[..., :self.qk_nope_head_dim] = k_nope
        k[..., self.qk_nope_head_dim:] = k_pe
        # padding value to qk_head_dim for alignment
        v = torch.nn.functional.pad(
            v, [0, self.qk_head_dim - self.v_head_dim],
            value=0).view(-1, self.num_local_heads * self.qk_head_dim)
        attn_output = self.attn(q, k, v)
        # Drop the padding added to v before the output projection.
        attn_output = attn_output.view(
            -1, self.num_local_heads,
            self.qk_head_dim)[..., :self.v_head_dim].reshape(
                -1, self.num_local_heads * self.v_head_dim)
        output, _ = self.o_proj(attn_output)
        return output


476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
class DeepseekV32IndexerCache(torch.nn.Module, AttentionLayerBase):
    """Placeholder layer that declares a KV cache for the indexer.

    The module does no computation. On construction it registers itself in
    the current vLLM config's ``static_forward_context`` under ``prefix``.
    """

    def __init__(self, head_dim: int, dtype: torch.dtype, prefix: str,
                 cache_config: CacheConfig):
        """Store the cache parameters and register under ``prefix``.

        Raises:
            ValueError: If ``prefix`` is already present in the static
                forward context.
        """
        super().__init__()
        # Empty placeholder. NOTE(review): presumably replaced with the real
        # cache tensor by the runtime -- confirm against the cache binding
        # code.
        self.kv_cache = [torch.tensor([])]
        self.head_dim = head_dim
        self.prefix = prefix
        self.cache_config = cache_config
        self.dtype = dtype
        compilation_config = get_current_vllm_config().compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self

    def get_kv_cache_spec(self) -> KVCacheSpec:
        """Return the cache spec: a single KV head of size ``head_dim``."""
        return MLAAttentionSpec(  # Only has one vector instead of K + V
            block_size=self.cache_config.block_size,
            num_kv_heads=1,
            head_size=self.head_dim,
            dtype=self.dtype,
        )

    def forward(self):
        """Do nothing; this layer exists only to own the cache."""
        ...

    def get_attn_backend(self) -> type[AttentionBackend]:
        """Return the backend class (not an instance) for this layer."""
        return DeepseekV32IndexerBackend


@torch.inference_mode()
def cp_gather_indexer_k_quant_cache(
    kv_cache,  # [num_blocks, block_size, head_dim + 1]
    dst_value,  # [cu_seq_lens[-1], head_dim]
    dst_scale,  # [cu_seq_lens[-1], 4]
    block_table,  # [batch_size, num_blocks]
    cu_seq_lens,  # [batch_size + 1, ]
    batch_size,
):
    """Gather each sequence's quantized keys and scales from the paged cache.

    Each cache block is flattened to one row: the first
    ``block_size * head_dim`` entries are the values, and the entries after
    that are the scales, read in groups of 4. For every sequence in the
    batch, the rows named by ``block_table`` are read, the last block is
    truncated to the tokens actually used, and the results are copied into
    ``dst_value`` (viewed as ``float8_e4m3fn``) and ``dst_scale`` (viewed as
    ``float32``).

    NOTE(review): the 4-wide scale groups presumably are the bytes of one
    float32 scale per token in a byte-typed cache -- confirm with the kernel
    that writes the cache.

    If the batch is empty or every sequence has length 0, the destinations
    are left untouched and the function returns without gathering.
    """
    num_blocks, block_size, _ = kv_cache.shape
    head_dim = dst_value.shape[-1]
    kv_cache = kv_cache.view(num_blocks, -1)
    # Offset inside a flattened block row where the scale entries start.
    scale_offset = block_size * head_dim

    expected_value = []
    expected_scale = []
    for b in range(batch_size):
        s = cu_seq_lens[b + 1] - cu_seq_lens[b]
        if s == 0:
            continue
        tot = cdiv(s, block_size)
        blocks = block_table[b, :tot]

        # All blocks except the last are completely filled.
        full_block = torch.arange(tot - 1,
                                  device=kv_cache.device,
                                  dtype=torch.int32)
        non_remaining_value = kv_cache[blocks[full_block], :scale_offset].view(
            -1, head_dim)
        non_remaining_scale = kv_cache[blocks[full_block],
                                       scale_offset:].view(-1, 4)

        # Number of tokens stored in the final (possibly partial) block.
        remaining = s - (tot - 1) * block_size

        value = torch.cat([
            non_remaining_value,
            kv_cache[blocks[-1], :remaining * head_dim].view(-1, head_dim)
        ],
                          dim=0)
        scale = torch.cat([
            non_remaining_scale,
            kv_cache[blocks[-1],
                     scale_offset:scale_offset + remaining * 4].view(-1, 4)
        ],
                          dim=0)

        expected_value.append(value)
        expected_scale.append(scale)

    if not expected_value:
        # Nothing to gather; torch.cat would raise on an empty list.
        return

    gather_value = torch.cat(expected_value, dim=0).view(-1, head_dim)
    gather_scale = torch.cat(expected_scale, dim=0).view(-1, 4)
    gather_value = gather_value.view(torch.float8_e4m3fn)
    gather_scale = gather_scale.view(torch.float32)
    dst_value.copy_(gather_value)
    dst_scale.copy_(gather_scale)


def sparse_attn_indexer(
    hidden_states: torch.Tensor,
    k_cache_prefix: str,
    kv_cache: torch.Tensor,
    q_fp8: torch.Tensor,
    k: torch.Tensor,
    weights: torch.Tensor,
    quant_block_size: int,
    scale_fmt: Optional[str],
    topk_tokens: int,
    head_dim: int,
    max_model_len: int,
    total_seq_lens: int,
    topk_indices_buffer: Optional[torch.Tensor],
) -> torch.Tensor:
    """Write the top-k key indices of every token into the shared buffer.

    When the forward context carries no per-layer attention metadata (the
    metadata is not a dict) the call is forwarded to
    ``sparse_attn_indexer_fake``. Otherwise the metadata stored under
    ``k_cache_prefix`` drives two phases:

    * prefill: per chunk, logits are computed, the top-k indices are taken,
      made relative to the start of each token's key range, and indices
      outside that range are replaced by -1;
    * decode: paged logits are computed, positions past each token's end
      position are masked to ``-inf`` before top-k, and indices past that
      end position are replaced by -1.

    The rows ``[:hidden_states.shape[0]]`` of ``topk_indices_buffer`` are
    reset to -1 first, so unfilled columns stay -1.

    Args:
        hidden_states: only its first dimension (token count) is read.
        k_cache_prefix: key of this layer's entry in the attention metadata.
        kv_cache: indexer key cache of this layer.
        q_fp8: query tensor handed to the logits kernels.
        k: keys of the current tokens.
        weights: per-token head weights handed to the logits kernels.
        quant_block_size: forwarded to ``ops.indexer_k_quant_and_cache``.
        scale_fmt: forwarded to ``ops.indexer_k_quant_and_cache``.
        topk_tokens: number of indices to select per token.
        head_dim: width of the gathered key buffer.
        max_model_len: width of the decode position mask.
        total_seq_lens: only used by the fake implementation.
        topk_indices_buffer: output buffer, mutated in place.

    Returns:
        ``topk_indices_buffer``.
    """
    # careful! this will be None in dummy run
    attn_metadata = get_forward_context().attn_metadata
    if not isinstance(attn_metadata, dict):
        return sparse_attn_indexer_fake(
            hidden_states,
            k_cache_prefix,
            kv_cache,
            q_fp8,
            k,
            weights,
            quant_block_size,
            scale_fmt,
            topk_tokens,
            head_dim,
            max_model_len,
            total_seq_lens,
            topk_indices_buffer,
        )
    attn_metadata = attn_metadata[k_cache_prefix]
    assert isinstance(attn_metadata, DeepseekV32IndexerMetadata)
    slot_mapping = attn_metadata.slot_mapping
    has_decode = attn_metadata.num_decodes > 0
    has_prefill = attn_metadata.num_prefills > 0
    num_decode_tokens = attn_metadata.num_decode_tokens

    # Resolve the platform once. The device architecture is only queried on
    # ROCm, exactly as the short-circuiting per-use checks did.
    on_rocm = current_platform.is_rocm()
    on_gfx938 = on_rocm and torch.cuda.get_device_properties(
        "cuda").gcnArchName.split(':')[0] == "gfx938"
    use_fp8_kernels = not on_rocm or on_gfx938

    if use_fp8_kernels:
        ops.indexer_k_quant_and_cache(
            k,
            kv_cache,
            slot_mapping,
            quant_block_size,
            scale_fmt,
        )

    topk_indices_buffer[:hidden_states.shape[0]] = -1
    if has_prefill:
        prefill_metadata = attn_metadata.prefill
        for chunk in prefill_metadata.chunks:
            q_chunk = q_fp8[chunk.token_start:chunk.token_end]
            weights_chunk = weights[chunk.token_start:chunk.token_end]
            if use_fp8_kernels:
                k_fp8 = torch.empty([chunk.total_seq_lens, head_dim],
                                    device=k.device,
                                    dtype=torch.float8_e4m3fn)
                k_scale = torch.empty([chunk.total_seq_lens, 1],
                                      device=k.device,
                                      dtype=torch.float32)
                cp_gather_indexer_k_quant_cache(
                    kv_cache,
                    k_fp8,
                    k_scale,
                    chunk.block_table,
                    chunk.cu_seq_lens,
                    chunk.num_reqs,
                )
                logits = fp8_mqa_logits(
                    q_chunk,
                    (k_fp8, k_scale),
                    weights_chunk,
                    chunk.cu_seqlen_ks,
                    chunk.cu_seqlen_ke,
                )
            else:
                # Only reached on ROCm devices other than gfx938, so the
                # weights are always converted to float32 here.
                # NOTE(review): the trailing 64 / 128 / True arguments are
                # opaque here (possibly head count and head dim) - confirm
                # against the op.mqa_logits signature.
                logits = op.mqa_logits(
                    q_chunk,
                    k,
                    weights_chunk.to(torch.float32),
                    chunk.cu_seqlen_ks,
                    chunk.cu_seqlen_ke,
                    q_chunk.shape[0],
                    k.shape[0],
                    64,
                    128,
                    True,
                )

            topk_indices = logits.topk(min(topk_tokens, logits.shape[-1]),
                                       dim=-1)[1]
            # Make the indices relative to the start of each token's range.
            topk_indices -= chunk.cu_seqlen_ks[:, None]
            mask_lo = topk_indices >= 0
            mask_hi = topk_indices - (chunk.cu_seqlen_ke -
                                      chunk.cu_seqlen_ks)[:, None] < 0
            mask = mask_lo & mask_hi
            topk_indices = topk_indices.masked_fill(~mask, -1)
            topk_indices_buffer[
                chunk.token_start:chunk.token_end, :topk_indices.
                shape[-1]] = topk_indices.to(dtype=torch.int32)

    if has_decode:
        decode_metadata = attn_metadata.decode
        # kv_cache size requirement [num_block, block_size, n_head, head_dim],
        # we only have [num_block, block_size, head_dim],
        kv_cache = kv_cache.unsqueeze(-2)
        decode_lens = decode_metadata.decode_lens
        if decode_metadata.requires_padding:
            # pad in edge case where we have short chunked prefill length <
            # decode_threshold since we unstrictly split
            # prefill and decode by decode_threshold
            # (currently set to 1 + speculative tokens)
            padded_q_fp8_decode_tokens = pack_seq_triton(
                q_fp8[:num_decode_tokens], decode_lens)
        else:
            padded_q_fp8_decode_tokens = q_fp8[:num_decode_tokens].reshape(
                decode_lens.shape[0], -1, *q_fp8.shape[1:])
        # TODO: move and optimize below logic with triton kernels
        batch_size = padded_q_fp8_decode_tokens.shape[0]
        next_n = padded_q_fp8_decode_tokens.shape[1]
        assert batch_size == decode_metadata.seq_lens.shape[0]
        # padded query len
        num_padded_tokens = batch_size * next_n

        if not on_rocm:
            logits = fp8_paged_mqa_logits(
                padded_q_fp8_decode_tokens,
                kv_cache,
                weights[:num_padded_tokens],
                decode_metadata.seq_lens,
                decode_metadata.block_table,
                decode_metadata.schedule_metadata,
                max_model_len=max_model_len,
            )
        else:
            decode_weights = weights[:num_padded_tokens]
            if on_gfx938:
                decode_kv_cache = kv_cache
            else:
                decode_kv_cache = kv_cache.to(torch.bfloat16)
                decode_weights = decode_weights.to(torch.float32)
            logits = gemmopt.paged_mqa_logits(
                padded_q_fp8_decode_tokens,
                decode_kv_cache,
                decode_weights,
                decode_metadata.seq_lens,
                decode_metadata.block_table,
                decode_metadata.schedule_metadata,
                max_model_len,
            )

        current_device = padded_q_fp8_decode_tokens.device
        positions = torch.arange(max_model_len,
                                 device=current_device).unsqueeze(0).expand(
                                     num_padded_tokens, -1)
        row_indices = torch.arange(num_padded_tokens,
                                   device=current_device) // next_n
        next_n_offset = torch.arange(num_padded_tokens,
                                     device=current_device) % next_n
        index_end_pos = (decode_metadata.seq_lens[row_indices] - next_n +
                         next_n_offset).unsqueeze(1)
        # index_end_pos: [B * N, 1]
        mask = positions <= index_end_pos
        # mask: [B * N, L]
        logits = logits.masked_fill(~mask, float('-inf'))
        # Clamp k to the logits width, as the prefill path does, so topk
        # cannot fail when fewer than topk_tokens positions exist.
        topk_indices = logits.topk(min(topk_tokens, logits.shape[-1]),
                                   dim=-1)[1].to(torch.int32)  # [B * N, K]
        # ensure we don't set indices for the top k
        # that is out of range(masked already)
        # this will happen if context length is shorter than K
        topk_indices[topk_indices > index_end_pos] = -1
        if decode_metadata.requires_padding:
            # if padded, we need to unpack
            # the topk indices removing padded tokens
            topk_indices = unpack_seq_triton(
                topk_indices.reshape(batch_size, -1, topk_indices.shape[-1]),
                decode_lens)
        topk_indices_buffer[:num_decode_tokens, :topk_indices.
                            shape[-1]] = topk_indices.to(dtype=torch.int32)

    return topk_indices_buffer


def sparse_attn_indexer_fake(
    hidden_states: torch.Tensor,
    k_cache_prefix: str,
    kv_cache: torch.Tensor,
    q_fp8: torch.Tensor,
    k: torch.Tensor,
    weights: torch.Tensor,
    quant_block_size: int,
    scale_fmt: Optional[str],
    topk_tokens: int,
    head_dim: int,
    max_model_len: int,
    total_seq_lens: int,
    topk_indices_buffer: Optional[torch.Tensor],
) -> torch.Tensor:
    """Stand-in for ``sparse_attn_indexer`` that computes no indices.

    It mirrors the signature of ``sparse_attn_indexer`` and returns
    ``topk_indices_buffer`` untouched. On platforms that take the FP8 path
    (non-ROCm, or ROCm gfx938) it allocates throw-away buffers sized by
    ``total_seq_lens`` so that a profile run accounts for that memory.
    """
    # profile run
    # NOTE(Chen): create the max possible flattened_kv. So that
    # profile_run can get correct memory usage.
    if not current_platform.is_rocm() or torch.cuda.get_device_properties(
            "cuda").gcnArchName.split(':')[0] == "gfx938":
        _flattened_kv = torch.empty([total_seq_lens, head_dim + 4],
                                    device=k.device,
                                    dtype=torch.uint8)
        # The contiguous() copies are intentionally unused: they exist only
        # to be allocated.
        _k_fp8 = _flattened_kv[..., :head_dim].view(
            torch.float8_e4m3fn).contiguous()
        _k_scale = _flattened_kv[..., head_dim:].view(
            torch.float32).contiguous()
    return topk_indices_buffer


# Register sparse_attn_indexer as a custom op (it is invoked in this file as
# torch.ops.vllm.sparse_attn_indexer). topk_indices_buffer is declared as the
# argument the op mutates, and sparse_attn_indexer_fake is supplied as the
# fake implementation.
direct_register_custom_op(
    op_name="sparse_attn_indexer",
    op_func=sparse_attn_indexer,
    mutates_args=["topk_indices_buffer"],
    fake_impl=sparse_attn_indexer_fake,
    dispatch_key=current_platform.dispatch_key,
)


class Indexer(nn.Module):
    """Projects queries/keys for the sparse-attention indexer.

    ``forward`` builds the indexer query and key from the low-rank query
    input and the hidden states, applies rotary embedding to the rope part
    of both, and calls the ``sparse_attn_indexer`` custom op, which returns
    the shared top-k indices buffer.
    """

    def __init__(self,
                 vllm_config: VllmConfig,
                 config: Union[DeepseekV2Config, DeepseekV3Config],
                 hidden_size: int,
                 q_lora_rank: int,
                 quant_config: Optional[QuantizationConfig],
                 cache_config: Optional[CacheConfig],
                 topk_indices_buffer: Optional[torch.Tensor],
                 prefix: str = ""):
        super().__init__()
        self.vllm_config = vllm_config
        self.config = config
        # self.indexer_cfg = config.attn_module_list_cfg[0]["attn_index"]
        self.topk_tokens = config.index_topk
        self.n_head = config.index_n_heads  # 64
        self.head_dim = config.index_head_dim  # 128
        self.rope_dim = config.qk_rope_head_dim  # 64
        self.q_lora_rank = q_lora_rank  # 1536
        # no tensor parallel, just replicated
        self.wq_b = ReplicatedLinear(self.q_lora_rank,
                                     self.head_dim * self.n_head,
                                     bias=False,
                                     quant_config=quant_config,
                                     prefix=f"{prefix}.wq_b")
        self.wk = ReplicatedLinear(hidden_size,
                                   self.head_dim,
                                   bias=False,
                                   quant_config=quant_config,
                                   prefix=f"{prefix}.wk")
        self.k_norm = LayerNorm(self.head_dim, eps=1e-6)
        self.weights_proj = ReplicatedLinear(hidden_size,
                                             self.n_head,
                                             quant_config=None,
                                             prefix=f"{prefix}.weights_proj")
        self.softmax_scale = self.head_dim**-0.5

        self.scale_fmt = "ue8m0"
        self.quant_block_size = 128  # TODO: get from config
        self.topk_indices_buffer = topk_indices_buffer

        # NOTE: (zyongye) we use fp8 naive cache,
        #       where we store value in fp8 and scale in fp32
        #       per self.quant_block_size element
        self.k_cache = DeepseekV32IndexerCache(
            head_dim=self.head_dim +
            self.head_dim // self.quant_block_size * 4,
            dtype=torch.uint8,
            prefix=f"{prefix}.k_cache",
            cache_config=cache_config)
        self.max_model_len = vllm_config.model_config.max_model_len
        self.prefix = prefix
        from vllm.v1.attention.backends.mla.indexer import (
            get_max_prefill_buffer_size)
        self.max_total_seq_len = get_max_prefill_buffer_size(vllm_config)

    def forward(self, hidden_states: torch.Tensor, qr: torch.Tensor, positions,
                rotary_emb) -> torch.Tensor:
        """Run the indexer and return the top-k indices buffer.

        Args:
            hidden_states: input to the key and head-weight projections.
            qr: input to the query projection ``wq_b``.
            positions: positions handed to ``rotary_emb``.
            rotary_emb: callable applied as
                ``rotary_emb(positions, q_pe, k_pe)``.

        Returns:
            The result of ``torch.ops.vllm.sparse_attn_indexer``.
        """
        q, _ = self.wq_b(qr)
        q = q.view(-1, self.n_head, self.head_dim)
        q_pe, q_nope = torch.split(
            q, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1)

        k, _ = self.wk(hidden_states)
        k = self.k_norm(k)
        k_pe, k_nope = torch.split(
            k, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1)

        q_pe, k_pe = rotary_emb(positions, q_pe, k_pe.unsqueeze(1))
        q = torch.cat([q_pe, q_nope], dim=-1)
        k = torch.cat([k_pe.squeeze(1), k_nope], dim=-1)

        # The FP8 path is taken everywhere except on ROCm devices other than
        # gfx938; evaluate the predicate once instead of at every use.
        use_fp8_path = (not current_platform.is_rocm()
                        or torch.cuda.get_device_properties(
                            "cuda").gcnArchName.split(':')[0] == "gfx938")

        # we only quant q here since k quant is fused with cache insertion
        if use_fp8_path:
            q_fp8, q_scale = per_token_group_quant_fp8(
                q.view(-1, self.head_dim),
                self.quant_block_size,
                column_major_scales=False,
                use_ue8m0=self.scale_fmt is not None)
            q_fp8 = q_fp8.view(-1, self.n_head, self.head_dim)
            q_scale = q_scale.view(-1, self.n_head, 1)
            indexer_q = q_fp8
        else:
            # Without FP8 quantization the unquantized query is passed on.
            indexer_q = q

        weights, _ = self.weights_proj(hidden_states)
        if use_fp8_path:
            # Fold the query scales and the softmax/head normalisation into
            # the per-head weights.
            weights = weights.unsqueeze(
                -1) * q_scale * self.softmax_scale * self.n_head**-0.5
            weights = weights.squeeze(-1)

        return torch.ops.vllm.sparse_attn_indexer(
            hidden_states,
            self.k_cache.prefix,
            self.k_cache.kv_cache[0],
            indexer_q,
            k,
            weights,
            self.quant_block_size,
            self.scale_fmt,
            self.topk_tokens,
            self.head_dim,
            self.max_model_len,
            self.max_total_seq_len,
            self.topk_indices_buffer,
        )
907
908


909
910
911
912
913
class DeepseekV2MLAAttention(nn.Module):
    """
    Main reference: DeepseekV2 paper, and FlashInfer Implementation
    (https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).

        For more info see MLACommonImpl in:
        vllm/v1/attention/backends/mla/utils.py
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        config: Union[DeepseekV2Config, DeepseekV3Config],
        hidden_size: int,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: Optional[int],
        kv_lora_rank: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        topk_indices_buffer: Optional[torch.Tensor] = None,
    ) -> None:
        """Build the projections, rotary embedding and MLA wrapper.

        When ``q_lora_rank`` is set, the query goes through a fused
        low-rank ``fused_qkv_a_proj`` followed by ``q_a_layernorm`` and
        ``q_b_proj``; otherwise a direct ``q_proj`` and a separate
        ``kv_a_proj_with_mqa`` are created. An ``Indexer`` is created only
        when ``config`` has an ``index_topk`` attribute.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim

        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank

        self.num_heads = num_heads
        tp_size = get_tensor_model_parallel_world_size()
        assert num_heads % tp_size == 0
        self.num_local_heads = num_heads // tp_size

        self.scaling = self.qk_head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        # Submodules are created in the same order as before so parameter
        # registration order is unchanged.
        if self.q_lora_rank is not None:
            self.fused_qkv_a_proj = MergedColumnParallelLinear(
                self.hidden_size,
                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.fused_qkv_a_proj",
                disable_tp=True)
            self.q_a_layernorm = RMSNorm(self.q_lora_rank,
                                         eps=config.rms_norm_eps)
            self.q_b_proj = ColumnParallelLinear(self.q_lora_rank,
                                                 self.num_heads *
                                                 self.qk_head_dim,
                                                 bias=False,
                                                 quant_config=quant_config,
                                                 prefix=f"{prefix}.q_b_proj")
        else:
            self.kv_a_proj_with_mqa = ReplicatedLinear(
                self.hidden_size,
                self.kv_lora_rank + self.qk_rope_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.kv_a_proj_with_mqa")
            self.q_proj = ColumnParallelLinear(self.hidden_size,
                                               self.num_heads *
                                               self.qk_head_dim,
                                               bias=False,
                                               quant_config=quant_config,
                                               prefix=f"{prefix}.q_proj")
        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
                                      eps=config.rms_norm_eps)
        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_b_proj")
        self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
                                        self.hidden_size,
                                        bias=False,
                                        quant_config=quant_config,
                                        prefix=f"{prefix}.o_proj")

        if rope_scaling:
            # NOTE: this writes into the dict passed by the caller.
            rope_scaling["rope_type"] = 'deepseek_yarn'
        self.rotary_emb = get_rope(qk_rope_head_dim,
                                   rotary_dim=qk_rope_head_dim,
                                   max_position=max_position_embeddings,
                                   base=rope_theta,
                                   rope_scaling=rope_scaling,
                                   is_neox_style=False)
        if rope_scaling:
            mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
            scaling_factor = rope_scaling["factor"]
            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
            self.scaling = self.scaling * mscale * mscale

        # The presence of index_topk on the config selects the sparse
        # (indexer-based) variant.
        self.is_v32 = hasattr(config, "index_topk")

        if self.is_v32:
            self.indexer = Indexer(vllm_config, config, hidden_size,
                                   q_lora_rank, quant_config, cache_config,
                                   topk_indices_buffer, f"{prefix}.indexer")
        else:
            self.indexer = None

        mla_modules = MLAModules(
            kv_a_layernorm=self.kv_a_layernorm,
            kv_b_proj=self.kv_b_proj,
            rotary_emb=self.rotary_emb,
            o_proj=self.o_proj,
            fused_qkv_a_proj=self.fused_qkv_a_proj
            if self.q_lora_rank is not None else None,
            kv_a_proj_with_mqa=self.kv_a_proj_with_mqa
            if self.q_lora_rank is None else None,
            q_a_layernorm=self.q_a_layernorm
            if self.q_lora_rank is not None else None,
            q_b_proj=self.q_b_proj if self.q_lora_rank is not None else None,
            q_proj=self.q_proj if self.q_lora_rank is None else None,
            indexer=self.indexer,
            is_sparse=self.is_v32,
            topk_indices_buffer=topk_indices_buffer,
        )

        self.mla_attn = MultiHeadLatentAttention(
            self.hidden_size,
            self.num_local_heads,
            self.scaling,
            self.qk_nope_head_dim,
            self.qk_rope_head_dim,
            self.v_head_dim,
            self.q_lora_rank,
            self.kv_lora_rank,
            mla_modules,
            cache_config,
            quant_config,
            prefix,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Delegate to the wrapped ``MultiHeadLatentAttention`` module."""
        return self.mla_attn(positions, hidden_states)
1064
1065


wangding zeng's avatar
wangding zeng committed
1066
1067
class DeepseekV2DecoderLayer(nn.Module):
    """One decoder layer: self-attention followed by a dense MLP or MoE.

    The attention class is chosen from ``model_config.use_mla``. The MLP is a
    ``DeepseekV2MoE`` when the config has routed experts and the layer index
    is at or past ``first_k_dense_replace`` and a multiple of
    ``moe_layer_freq``; otherwise it is a ``DeepseekV2MLP``.
    """

    def __init__(self,
                 vllm_config: VllmConfig,
                 prefix: str,
                 topk_indices_buffer: Optional[torch.Tensor] = None) -> None:
        super().__init__()

        config = vllm_config.model_config.hf_config
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        parallel_config = vllm_config.parallel_config

        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        # DecoderLayers are created with `make_layers` which passes the prefix
        # with the layer's index.
        layer_idx = int(prefix.split(sep='.')[-1])
        self.layer_idx = layer_idx
        if model_config.use_mla:
            attn_cls = DeepseekV2MLAAttention
        else:
            attn_cls = DeepseekV2Attention
        self.self_attn = attn_cls(
            vllm_config=vllm_config,
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            qk_nope_head_dim=config.qk_nope_head_dim,
            qk_rope_head_dim=config.qk_rope_head_dim,
            v_head_dim=config.v_head_dim,
            q_lora_rank=config.q_lora_rank
            if hasattr(config, "q_lora_rank") else None,
            kv_lora_rank=config.kv_lora_rank,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
            topk_indices_buffer=topk_indices_buffer,
        )
        if (config.n_routed_experts is not None
                and layer_idx >= config.first_k_dense_replace
                and layer_idx % config.moe_layer_freq == 0):
            self.mlp = DeepseekV2MoE(
                config=config,
                parallel_config=parallel_config,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )
        else:
            self.mlp = DeepseekV2MLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)
        self.routed_scaling_factor = config.routed_scaling_factor

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Run attention and MLP; return ``(hidden_states, residual)``.

        With ``envs.USE_FUSED_RMS_QUANT`` set, the layer-norm weights and the
        residual are handed to the attention and MLP modules instead of
        calling the layer norms here. In both paths float16 activations are
        divided by ``routed_scaling_factor`` to avoid overflow.
        """
        if envs.USE_FUSED_RMS_QUANT:
            # Fix residual FP16 overflow
            residual_fix_overflow = False

            assert self.input_layernorm.has_weight is True
            # NOTE(review): this path passes rms_weight / residual to
            # self_attn, while DeepseekV2MLAAttention.forward as visible in
            # this file accepts only positions and hidden_states - confirm
            # which attention class supports the fused path.
            if residual is None:
                residual = hidden_states
                hidden_states, _ = self.self_attn(
                    positions=positions,
                    hidden_states=hidden_states,
                    rms_weight=self.input_layernorm.weight.data,
                    residual=None,
                )
                residual_fix_overflow = True
            else:
                hidden_states, new_residual = self.self_attn(
                    positions=positions,
                    hidden_states=hidden_states,
                    rms_weight=self.input_layernorm.weight.data,
                    residual=residual,
                )
                residual = new_residual

            if hidden_states.dtype == torch.float16:
                # Fix FP16 overflow
                # We scale both hidden_states and residual before
                # rmsnorm, and rmsnorm result would not affect by scale.
                hidden_states *= 1. / self.routed_scaling_factor
                if self.layer_idx == 0 or residual_fix_overflow:
                    # The residual is shared by all layers, we only scale it on
                    # first layer.
                    residual *= 1. / self.routed_scaling_factor

            hidden_states, new_resi = self.mlp(
                hidden_states, self.post_attention_layernorm.weight.data,
                residual)

            if isinstance(self.mlp,
                          DeepseekV2MLP) and hidden_states.dtype == torch.float16:
                # Fix FP16 overflow
                # Scaling the DeepseekV2MLP output, it is the input of
                # input_layernorm of next decoder layer.
                # The scaling of DeepseekV2MOE output would be done in the forward
                # of DeepseekV2MOE
                hidden_states *= 1. / self.routed_scaling_factor
            return hidden_states, new_resi
        else:
            # Self Attention
            # Fix residual FP16 overflow
            residual_fix_overflow = False
            if residual is None:
                residual = hidden_states.clone()
                hidden_states = self.input_layernorm(hidden_states)
                residual_fix_overflow = True
            else:
                hidden_states, residual = self.input_layernorm(
                    hidden_states, residual)
            hidden_states = self.self_attn(
                positions=positions,
                hidden_states=hidden_states,
            )

            if hidden_states.dtype == torch.float16:
                # Fix FP16 overflow
                # We scale both hidden_states and residual before
                # rmsnorm, and rmsnorm result would not affect by scale.
                hidden_states *= 1. / self.routed_scaling_factor
                if self.layer_idx == 0 or residual_fix_overflow:
                    # The residual is shared by all layers, we only scale it on
                    # first layer.
                    residual *= 1. / self.routed_scaling_factor

            # Fully Connected
            hidden_states, residual = self.post_attention_layernorm(
                hidden_states, residual)
            hidden_states = self.mlp(hidden_states)
            if isinstance(self.mlp,
                          DeepseekV2MLP) and hidden_states.dtype == torch.float16:
                # Fix FP16 overflow
                # Scaling the DeepseekV2MLP output, it is the input of
                # input_layernorm of next decoder layer.
                # The scaling of DeepseekV2MOE output would be done in the forward
                # of DeepseekV2MOE
                hidden_states *= 1. / self.routed_scaling_factor
            return hidden_states, residual
wangding zeng's avatar
wangding zeng committed
1222
1223


1224
@support_torch_compile
class DeepseekV2Model(nn.Module):
    """Embedding + decoder-layer stack + final norm for DeepseekV2/V3.

    The model is pipeline-parallel aware: the token embedding exists only on
    the first pipeline rank and the final RMSNorm only on the last one; the
    other ranks hold ``PPMissingLayer`` placeholders in their place.
    """

    fall_back_to_pt_during_load = False

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the sub-modules from ``vllm_config``.

        Args:
            vllm_config: Source of the HF model config, the quantization
                config and the scheduler config.
            prefix: Dotted name prefix used for the sub-module names.
        """
        super().__init__()

        hf_config = vllm_config.model_config.hf_config
        quantization = vllm_config.quant_config

        self.config = hf_config
        self.vocab_size = hf_config.vocab_size
        # The presence of `index_topk` in the config is what marks a V3.2
        # style model here.
        self.is_v32 = hasattr(hf_config, "index_topk")

        # One int32 buffer of shape (max_num_batched_tokens, index_topk) is
        # allocated once and handed to every decoder layer; without
        # `index_topk` the layers receive None instead.
        topk_indices_buffer = None
        if self.is_v32:
            topk_indices_buffer = torch.empty(
                vllm_config.scheduler_config.max_num_batched_tokens,
                hf_config.index_topk,
                dtype=torch.int32,
                device="cuda",
            )

        if not get_pp_group().is_first_rank:
            self.embed_tokens = PPMissingLayer()
        else:
            self.embed_tokens = VocabParallelEmbedding(
                hf_config.vocab_size,
                hf_config.hidden_size,
                quant_config=quantization,
                prefix=f"{prefix}.embed_tokens",
            )

        # NOTE(review): the parameter is deliberately still called `prefix`
        # in case make_layers passes it by keyword - confirm against
        # make_layers before renaming.
        def build_layer(prefix: str) -> DeepseekV2DecoderLayer:
            return DeepseekV2DecoderLayer(vllm_config, prefix,
                                          topk_indices_buffer)

        layer_range = make_layers(
            hf_config.num_hidden_layers,
            build_layer,
            prefix=f"{prefix}.layers",
        )
        self.start_layer, self.end_layer, self.layers = layer_range

        if not get_pp_group().is_last_rank:
            self.norm = PPMissingLayer()
        else:
            self.norm = RMSNorm(hf_config.hidden_size,
                                eps=hf_config.rms_norm_eps)

        # Factory for the tensors exchanged between pipeline stages; the keys
        # match what forward() reads and emits.
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], hf_config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of the decoder stack.

        Args:
            input_ids: Token ids; embedded on the first pipeline rank unless
                ``inputs_embeds`` is supplied.
            positions: Passed through to every decoder layer.
            intermediate_tensors: Output of the previous pipeline stage;
                required (asserted) on every rank but the first.
            inputs_embeds: Optional precomputed embeddings that replace the
                embedding lookup on the first rank.

        Returns:
            The normed hidden states on the last pipeline rank, otherwise an
            ``IntermediateTensors`` carrying ``hidden_states`` and
            ``residual`` for the next stage.
        """
        if get_pp_group().is_first_rank:
            hidden_states = (inputs_embeds if inputs_embeds is not None else
                             self.get_input_embeddings(input_ids))
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        # Only the layers owned by this pipeline rank are executed.
        local_layers = islice(self.layers, self.start_layer, self.end_layer)
        for decoder_layer in local_layers:
            hidden_states, residual = decoder_layer(positions, hidden_states,
                                                    residual)

        if get_pp_group().is_last_rank:
            hidden_states, _ = self.norm(hidden_states, residual)
            return hidden_states

        return IntermediateTensors({
            "hidden_states": hidden_states,
            "residual": residual
        })


1305
1306
class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts,
                            SupportsLoRA):
    """DeepseekV2 decoder stack plus LM head, logits processor and loader.

    Besides wrapping ``DeepseekV2Model`` this class records the MoE
    hyperparameters read from the last MoE layer it finds, exposes the EPLB
    hooks (``set_eplb_state`` / ``update_physical_experts_metadata``) and
    implements checkpoint loading in ``load_weights``.
    """

    packed_modules_mapping = {
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the model, the LM head and the MoE bookkeeping.

        Args:
            vllm_config: Source of the HF model config and quantization
                config.
            prefix: Dotted name prefix used for the sub-module names.

        Raises:
            RuntimeError: If no ``DeepseekV2MoE`` layer exists among the
                layers held by this rank.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config

        self.quant_method = None
        if quant_config is not None:
            self.quant_method = quant_config.get_name()
            # NOTE: process-wide side effect - a quantized model forces both
            # switches off for everything that reads them afterwards,
            # including `use_llama_nn` below.
            os.environ['LLAMA_NN'] = '0'
            os.environ['LM_NN'] = '0'

        self.use_w4a16_moe_sz = os.environ.get('AWQ_MOE_SZ') == '1'
        self.config = config
        self.quant_config = quant_config

        # `packed_modules_mapping` needs to be modified before
        # initializing DeepseekV2Model, as it is passed inplace to
        # quantization config init and may be used to select the
        # quant_method for relevant layers during initialization.
        self.fuse_qkv_a_proj = hasattr(
            config, "q_lora_rank") and config.q_lora_rank is not None
        if self.fuse_qkv_a_proj:
            self.packed_modules_mapping["fused_qkv_a_proj"] = [
                "q_a_proj",
                "kv_a_proj_with_mqa",
            ]

        self.model = DeepseekV2Model(vllm_config=vllm_config,
                                     prefix=maybe_prefix(prefix, "model"))
        # The LM head only exists on the last pipeline rank.
        if get_pp_group().is_last_rank:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
        else:
            self.lm_head = PPMissingLayer()
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)
        self.expert_weights = []

        # Set MoE hyperparameters
        self.num_moe_layers = (config.num_hidden_layers -
                               config.first_k_dense_replace)
        self.num_expert_groups = config.n_group

        self.moe_layers: list[FusedMoE] = []
        example_moe = None
        for layer in self.model.layers:
            # Layers owned by other pipeline ranks are placeholders.
            if isinstance(layer, PPMissingLayer):
                continue

            assert isinstance(layer, DeepseekV2DecoderLayer)
            if isinstance(layer.mlp, DeepseekV2MoE):
                # Pick last one layer since the first ones may be dense layers.
                example_moe = layer.mlp
                self.moe_layers.append(layer.mlp.experts)

        if example_moe is None:
            raise RuntimeError("No DeepseekV2MoE layer found in model.layers.")

        self.num_logical_experts = example_moe.n_logical_experts
        self.num_physical_experts = example_moe.n_physical_experts
        self.num_local_physical_experts = example_moe.n_local_physical_experts
        self.num_routed_experts = example_moe.n_routed_experts
        self.num_shared_experts = example_moe.n_shared_experts
        self.num_redundant_experts = example_moe.n_redundant_experts

        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
        self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
        # NOTE(review): W8a8GetCacheJSON is presumably a shared singleton
        # (going by the attribute name) - confirm; the two assignments below
        # would then affect every user of it.
        self.tritonsingleton = W8a8GetCacheJSON()
        self.tritonsingleton.topk = config.num_experts_per_tok
        self.tritonsingleton.quant_method = self.quant_method

    def set_eplb_state(
        self,
        expert_load_view: torch.Tensor,
        logical_to_physical_map: torch.Tensor,
        logical_replica_count: torch.Tensor,
    ) -> None:
        """Forward the EPLB state to every local MoE layer.

        Each layer's expert weights are appended to ``self.expert_weights``
        and the layer is told its index within ``self.moe_layers``.
        """
        for layer_idx, layer in enumerate(self.moe_layers):
            # Register the expert weights.
            self.expert_weights.append(layer.get_expert_weights())
            layer.set_eplb_state(
                moe_layer_idx=layer_idx,
                expert_load_view=expert_load_view,
                logical_to_physical_map=logical_to_physical_map,
                logical_replica_count=logical_replica_count,
            )

    def update_physical_experts_metadata(
        self,
        num_physical_experts: int,
        num_local_physical_experts: int,
    ) -> None:
        """Update the physical-expert counts here and on every MoE layer.

        Args:
            num_physical_experts: New total number of physical experts.
            num_local_physical_experts: Must equal the current local count
                (asserted).
        """
        assert self.num_local_physical_experts == num_local_physical_experts
        self.num_physical_experts = num_physical_experts
        self.num_local_physical_experts = num_local_physical_experts
        self.num_redundant_experts = (num_physical_experts -
                                      self.num_logical_experts)
        for layer in self.model.layers:
            # `model.layers` can hold PPMissingLayer placeholders (the
            # constructor skips them for the same reason); they are not
            # decoder layers, so do not touch `.mlp` on them.
            if isinstance(layer, PPMissingLayer):
                continue
            if isinstance(layer.mlp, DeepseekV2MoE):
                moe = layer.mlp
                moe.n_local_physical_experts = num_local_physical_experts
                moe.n_physical_experts = num_physical_experts
                moe.n_redundant_experts = self.num_redundant_experts
                moe.experts.update_expert_map()

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the token embeddings for ``input_ids``."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Delegate to the wrapped ``DeepseekV2Model`` and return its output."""
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        """Run the logits processor over ``hidden_states`` with the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this model's parameters.

        Each ``(name, tensor)`` pair is routed, in order of precedence, to a
        stacked parameter (``gate_up_proj`` / ``fused_qkv_a_proj``), to an
        expert parameter, or to the parameter of the same name. Rotary
        ``inv_freq`` buffers, speculative-decoding layers, parameters that
        live on another pipeline rank and names unknown to this model are
        skipped.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The set of parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
            ("fused_qkv_a_proj", "q_a_proj", 0),
            ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
        ]

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.n_routed_experts,
            num_redundant_experts=self.num_redundant_experts)

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
            if spec_layer is not None:
                continue  # skip spec decode layers for main model

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if (("mlp.experts." in name) and name not in params_dict):
                    continue
                name_mapped = name.replace(weight_name, param_name)

                # QKV fusion is optional, fall back to normal
                # weight loading if it's not enabled
                # if go with fusion option, then update name
                if ((param_name == "fused_qkv_a_proj")
                        and name_mapped not in params_dict):
                    continue
                else:
                    name = name_mapped
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                is_expert_weight = False
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue

                    # Anyway, this is an expert weight and should not be
                    # attempted to load as other weights later
                    is_expert_weight = True

                    # Do not modify `name` since the loop may continue here
                    # Instead, create a new variable
                    name_mapped = name.replace(weight_name, param_name)

                    if is_pp_missing_parameter(name_mapped, self):
                        continue

                    param = params_dict[name_mapped]
                    # We should ask the weight loader to return success or not
                    # here since otherwise we may skip experts with other
                    # available replicas.
                    weight_loader = typing.cast(Callable[..., bool],
                                                param.weight_loader)
                    success = weight_loader(param,
                                            loaded_weight,
                                            name_mapped,
                                            shard_id=shard_id,
                                            expert_id=expert_id,
                                            return_success=True)
                    if success:
                        name = name_mapped
                        break
                else:
                    if is_expert_weight:
                        # We've checked that this is an expert weight
                        # However it's not mapped locally to this rank
                        # So we simply skip it
                        continue

                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue

                    # Remapping the name of FP8 kv-scale.
                    name = maybe_remap_kv_scale_name(name, params_dict)
                    if name is None:
                        continue

                    if is_pp_missing_parameter(name, self):
                        continue

                    # Checkpoint entries with no matching parameter are
                    # skipped on purpose; only the missing-key case is
                    # tolerated so that real errors still surface.
                    try:
                        param = params_dict[name]
                    except KeyError:
                        continue

                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)

        if self.use_llama_nn and self.quant_method is None:
            lay_key_words = [
                "self_attn.q_proj.weight",
                "self_attn.q_a_proj.weight",
                "self_attn.q_b_proj.weight",
                "self_attn.kv_a_proj_with_mqa.weight",
                "self_attn.kv_b_proj.weight",
                "self_attn.o_proj.weight",
                "mlp.gate_up_proj.weight",
                "mlp.down_proj.weight",
                "mlp.gate.weight",
                "shared_experts.gate_up_proj.weight",
                "shared_experts.down_proj.weight",
                "lm_head.weight",
            ]

            # Compile once; the keywords are literal parameter-name
            # fragments, so their dots are escaped rather than treated as
            # regex wildcards.
            layer_pattern = re.compile("|".join(
                re.escape(word) for word in lay_key_words))

            for layername in loaded_params:
                weight = params_dict[layername]
                if layer_pattern.search(layername) is None:
                    continue

                _weight = torch.zeros_like(weight.data)
                ori_shape = _weight.shape

                # NOTE(review): presumably writes a re-laid-out
                # (transposed) copy of the weight into `_weight` - confirm
                # against the ops.trans_w16_gemm implementation.
                ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0],
                                   _weight.shape[1])
                weight.data.copy_(_weight)

                # Same storage, viewed with the second original dimension
                # leading.
                weight.data = weight.data.reshape(ori_shape[1], -1)

        return loaded_params
1599
1600
1601
1602


class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
    """DeepseekV3 entry point; inherits DeepseekV2ForCausalLM unchanged."""
1603
1604


1605
1606
1607
1608
# Compatibility with
# https://huggingface.co/deepseek-ai/DeepSeek-V3-Base/blob/main/configuration_deepseek.py
def get_spec_layer_idx_from_weight_name(config: Union[DeepseekV2Config,
                                                      DeepseekV3Config],
                                        weight_name: str) -> Optional[int]:
    """Return the next-n-predict layer index that owns ``weight_name``.

    The ``num_nextn_predict_layers`` extra layers are numbered directly
    after the regular ones, i.e. starting at ``config.num_hidden_layers``.

    Args:
        config: Model config; ``num_nextn_predict_layers`` may be absent.
        weight_name: Checkpoint weight name.

    Returns:
        The matching layer index when ``weight_name`` starts with
        ``model.layers.<idx>.`` for one of those extra layers, else ``None``
        (also when the config declares no such layers).
    """
    extra_layers = getattr(config, "num_nextn_predict_layers", 0)
    if extra_layers > 0:
        first_extra = config.num_hidden_layers
        for candidate in range(first_extra, first_extra + extra_layers):
            # The trailing dot keeps e.g. layer 61 from matching layer 610.
            if weight_name.startswith(f"model.layers.{candidate}."):
                return candidate
    return None