deepseek_v2.py 61.3 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
14

Liangsheng Yin's avatar
Liangsheng Yin committed
15
16
17
# Adapted from:
# https://github.com/vllm-project/vllm/blob/fb6af8bc086328ca6659e72d11ffd4309ce4de22/vllm/model_executor/models/deepseek_v2.py
"""Inference-only DeepseekV2 model."""
18

19
import logging
20
import os
Liangsheng Yin's avatar
Liangsheng Yin committed
21
22
23
from typing import Any, Dict, Iterable, Optional, Tuple

import torch
Ke Bao's avatar
Ke Bao committed
24
import torch.nn.functional as F
Liangsheng Yin's avatar
Liangsheng Yin committed
25
from torch import nn
26
from tqdm import tqdm
Liangsheng Yin's avatar
Liangsheng Yin committed
27
from transformers import PretrainedConfig
28
29

from sglang.srt.distributed import (
Liangsheng Yin's avatar
Liangsheng Yin committed
30
    get_tensor_model_parallel_world_size,
31
    parallel_state,
Liangsheng Yin's avatar
Liangsheng Yin committed
32
33
    tensor_model_parallel_all_reduce,
)
34
from sglang.srt.layers.activation import SiluAndMul
Lianmin Zheng's avatar
Lianmin Zheng committed
35
from sglang.srt.layers.dp_attention import (
36
    dp_gather_partial,
Lianmin Zheng's avatar
Lianmin Zheng committed
37
38
39
40
    dp_scatter,
    get_attention_dp_size,
    get_attention_tp_rank,
    get_attention_tp_size,
41
42
    tp_all_gather,
    tp_reduce_scatter,
Lianmin Zheng's avatar
Lianmin Zheng committed
43
)
44
from sglang.srt.layers.layernorm import RMSNorm
45
46
47
48
49
50
from sglang.srt.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
Liangsheng Yin's avatar
Liangsheng Yin committed
51
from sglang.srt.layers.logits_processor import LogitsProcessor
52
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE
Ke Bao's avatar
Ke Bao committed
53
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
54
from sglang.srt.layers.moe.topk import select_experts
55
from sglang.srt.layers.quantization.base_config import QuantizationConfig
HandH1998's avatar
HandH1998 committed
56
57
58
from sglang.srt.layers.quantization.fp8_utils import (
    block_quant_to_tensor_quant,
    input_to_float8,
59
    normalize_e4m3fn_to_e4m3fnuz,
HandH1998's avatar
HandH1998 committed
60
)
61
62
63
from sglang.srt.layers.quantization.int8_utils import (
    block_dequant as int8_block_dequant,
)
Liangsheng Yin's avatar
Liangsheng Yin committed
64
from sglang.srt.layers.radix_attention import RadixAttention
65
from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper
66
67
68
69
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
70
from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
71
from sglang.srt.managers.schedule_batch import global_server_args_dict
72
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
73
from sglang.srt.model_loader.weight_utils import default_weight_loader
74
from sglang.srt.utils import DeepEPMode, add_prefix, is_cuda, is_hip
75

76
# Platform flags are evaluated once at import time; they gate the optional
# kernel imports below.
_is_hip = is_hip()
_is_cuda = is_cuda()

if not _is_cuda:
    # Without CUDA the custom ops are taken from vLLM.
    from vllm import _custom_ops as ops
else:
    from sgl_kernel import awq_dequantize, bmm_fp8

    from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher

if _is_hip:
    from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
        decode_attention_fwd_grouped_rope,
    )

# Module-wide recorder instance.
expert_distribution_recorder = ExpertDistributionRecorder()

logger = logging.getLogger(__name__)

Liangsheng Yin's avatar
Liangsheng Yin committed
95
96
97
98
99
100
101
102
103

class DeepseekV2MLP(nn.Module):
    """Gated feed-forward block: ``down_proj(SiluAndMul(gate_up_proj(x)))``.

    Args:
        hidden_size: Input size of ``gate_up_proj`` and output size of
            ``down_proj``.
        intermediate_size: Size of each of the two merged gate/up projections
            and input size of ``down_proj``.
        hidden_act: Activation name; only ``"silu"`` is accepted.
        quant_config: Optional quantization config forwarded to both linears.
        reduce_results: Forwarded to the ``RowParallelLinear`` down projection.
        prefix: Name prefix used for the sub-layer prefixes.
        tp_rank: Forwarded to both parallel linear layers.
        tp_size: Forwarded to both parallel linear layers.

    Raises:
        ValueError: If ``hidden_act`` is not ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
        prefix: str = "",
        tp_rank: Optional[int] = None,
        tp_size: Optional[int] = None,
    ) -> None:
        super().__init__()
        # Reject unsupported activations before any projection layer is
        # built, mirroring the validate-first order used by DeepseekV2MoE.
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. "
                "Only silu is supported for now."
            )
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("gate_up_proj", prefix),
            tp_rank=tp_rank,
            tp_size=tp_size,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=add_prefix("down_proj", prefix),
            tp_rank=tp_rank,
            tp_size=tp_size,
        )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply the merged gate/up projection, the activation, then down_proj."""
        # Both linear layers return a tuple; only the first element is used.
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


Ke Bao's avatar
Ke Bao committed
142
class MoEGate(nn.Module):
    """Bias-free linear router that emits one logit per routed expert.

    Owns a ``(config.n_routed_experts, config.hidden_size)`` weight. When
    ``config.topk_method == "noaux_tc"`` it also owns a per-expert
    ``e_score_correction_bias`` parameter; otherwise that attribute is
    ``None``. Both tensors are created with ``torch.empty`` and are not
    initialized here.
    """

    def __init__(
        self,
        config,
        prefix: str = "",
    ):
        super().__init__()
        num_experts = config.n_routed_experts
        self.weight = nn.Parameter(torch.empty((num_experts, config.hidden_size)))

        uses_correction_bias = config.topk_method == "noaux_tc"
        self.e_score_correction_bias = (
            nn.Parameter(torch.empty(num_experts)) if uses_correction_bias else None
        )

    def forward(self, hidden_states):
        """Return ``hidden_states @ weight.T`` (no bias term)."""
        return F.linear(hidden_states, self.weight, None)


Liangsheng Yin's avatar
Liangsheng Yin committed
164
165
166
167
168
169
class DeepseekV2MoE(nn.Module):
    """Mixture-of-experts block: router gate, routed experts, optional shared MLP.

    The routed-expert implementation is picked from the server arguments:
    ``DeepEPMoE`` when ``enable_deepep_moe`` is set, otherwise ``EPMoE`` when
    ``enable_ep_moe`` is set, otherwise ``FusedMoE``. A separate
    ``shared_experts`` MLP is created only when ``config.n_shared_experts`` is
    not None and ``n_share_experts_fusion`` is 0.

    Args:
        config: Model config providing the expert counts, sizes and routing
            options read below.
        quant_config: Optional quantization config forwarded to the experts
            and to the shared-expert MLP.
        prefix: Name prefix used for the sub-layer prefixes.

    Raises:
        ValueError: If the tensor-parallel world size exceeds
            ``config.n_routed_experts`` or ``config.hidden_act`` is not
            ``"silu"``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.routed_scaling_factor = config.routed_scaling_factor
        self.n_shared_experts = config.n_shared_experts

        # A missing (None) server argument means no shared experts are fused.
        n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
        self.n_share_experts_fusion = (
            n_share_experts_fusion if n_share_experts_fusion is not None else 0
        )

        if self.tp_size > config.n_routed_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.n_routed_experts}."
            )

        if config.hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {config.hidden_act}. "
                "Only silu is supported for now."
            )

        self.gate = MoEGate(config=config, prefix=add_prefix("gate", prefix))

        # Read the DeepEP switches once; they drive several choices below.
        enable_deepep_moe = global_server_args_dict["enable_deepep_moe"]
        if enable_deepep_moe:
            MoEImpl = DeepEPMoE
            deepep_mode = DeepEPMode[global_server_args_dict["deepep_mode"]]
            deepep_kwargs = dict(deepep_mode=deepep_mode)
        else:
            MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
            deepep_mode = None
            deepep_kwargs = {}

        self.experts = MoEImpl(
            # Fused shared experts are appended to the routed experts and add
            # at most one extra top-k slot.
            num_experts=config.n_routed_experts + self.n_share_experts_fusion,
            top_k=config.num_experts_per_tok + min(self.n_share_experts_fusion, 1),
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            renormalize=config.norm_topk_prob,
            quant_config=quant_config,
            use_grouped_topk=True,
            num_expert_group=config.n_group,
            topk_group=config.topk_group,
            correction_bias=self.gate.e_score_correction_bias,
            prefix=add_prefix("experts", prefix),
            **deepep_kwargs,
        )

        if config.n_shared_experts is not None and self.n_share_experts_fusion == 0:
            intermediate_size = config.moe_intermediate_size * config.n_shared_experts
            # disable tp for shared experts when enable deepep moe
            shared_tp_kwargs = dict(tp_rank=0, tp_size=1) if enable_deepep_moe else {}
            self.shared_experts = DeepseekV2MLP(
                hidden_size=config.hidden_size,
                intermediate_size=intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                reduce_results=False,
                prefix=add_prefix("shared_experts", prefix),
                **shared_tp_kwargs,
            )

        if enable_deepep_moe:
            # TODO: we will support tp < ep in the future
            self.ep_size = get_tensor_model_parallel_world_size()
            self.num_experts = config.n_routed_experts
            self.top_k = config.num_experts_per_tok
            self.renormalize = config.norm_topk_prob
            self.topk_group = config.topk_group
            self.num_expert_group = config.n_group
            self.correction_bias = (
                self.gate.e_score_correction_bias.data
                if self.gate.e_score_correction_bias is not None
                else None
            )

            self.deepep_dispatcher = DeepEPDispatcher(
                group=parallel_state.get_tp_group().device_group,
                router_topk=self.top_k,
                permute_fusion=True,
                num_experts=config.n_routed_experts,
                num_local_experts=config.n_routed_experts // self.tp_size,
                hidden_size=config.hidden_size,
                params_dtype=config.torch_dtype,
                deepep_mode=deepep_mode,
                async_finish=True,  # TODO
                return_recv_hook=True,
            )

    def forward(
        self, hidden_states: torch.Tensor, forward_mode: Optional[ForwardMode] = None
    ) -> torch.Tensor:
        """Dispatch to the DeepEP path when ``enable_deepep_moe`` is set."""
        if not global_server_args_dict["enable_deepep_moe"]:
            return self.forward_normal(hidden_states)
        else:
            return self.forward_deepep(hidden_states, forward_mode)

    def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Run routed experts (scaled) plus shared experts, then all-reduce.

        The all-reduce is applied only when ``self.tp_size > 1``.
        """
        if self.n_shared_experts is not None and self.n_share_experts_fusion == 0:
            shared_output = self.shared_experts(hidden_states)
        else:
            shared_output = None
        # router_logits: (num_tokens, n_experts)
        router_logits = self.gate(hidden_states)
        final_hidden_states = (
            self.experts(hidden_states=hidden_states, router_logits=router_logits)
            * self.routed_scaling_factor
        )
        if shared_output is not None:
            final_hidden_states = final_hidden_states + shared_output
        if self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
        return final_hidden_states

    def forward_deepep(
        self, hidden_states: torch.Tensor, forward_mode: ForwardMode
    ) -> torch.Tensor:
        """Route tokens locally, dispatch, run the experts, then combine.

        When ``forward_mode`` is None or idle, or the batch has no rows, the
        routing step is skipped and zero-row ``topk_idx``/``topk_weights``
        placeholders are passed to the dispatcher instead.
        """
        shared_output = None
        # Zero-row placeholders used when there is nothing to route locally.
        topk_idx = torch.full(
            (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
        )
        topk_weights = torch.empty(
            (0, self.top_k), dtype=torch.float32, device=hidden_states.device
        )
        if (
            forward_mode is not None
            and not forward_mode.is_idle()
            and hidden_states.shape[0] > 0
        ):
            # router_logits: (num_tokens, n_experts)
            router_logits = self.gate(hidden_states)
            # NOTE(review): unlike forward_normal, this guard does not check
            # n_share_experts_fusion == 0, although __init__ creates
            # shared_experts only in that case - confirm the two options are
            # never combined.
            if self.n_shared_experts is not None:
                shared_output = self.shared_experts(hidden_states)
            topk_weights, topk_idx = select_experts(
                hidden_states=hidden_states,
                router_logits=router_logits,
                top_k=self.top_k,
                use_grouped_topk=True,
                renormalize=self.renormalize,
                topk_group=self.topk_group,
                num_expert_group=self.num_expert_group,
                correction_bias=self.correction_bias,
            )
        if self.ep_size > 1:
            # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
            (
                hidden_states,
                topk_idx,
                topk_weights,
                reorder_topk_ids,
                seg_indptr,
                masked_m,
                expected_m,
            ) = self.deepep_dispatcher.dispatch(
                hidden_states,
                topk_idx,
                topk_weights,
                forward_mode=forward_mode,
            )
        # NOTE(review): reorder_topk_ids, seg_indptr, masked_m and expected_m
        # are bound only by the dispatch call above, so this call raises
        # UnboundLocalError when ep_size <= 1.
        final_hidden_states = (
            self.experts(
                hidden_states=hidden_states,
                reorder_topk_ids=reorder_topk_ids,
                seg_indptr=seg_indptr,
                masked_m=masked_m,
                expected_m=expected_m,
                forward_mode=forward_mode,
            )
            * self.routed_scaling_factor
        )
        if self.ep_size > 1:
            final_hidden_states = self.deepep_dispatcher.combine(
                final_hidden_states,
                topk_idx,
                topk_weights,
                forward_mode,
            )
        if shared_output is not None:
            final_hidden_states = final_hidden_states + shared_output

        return final_hidden_states
Liangsheng Yin's avatar
Liangsheng Yin committed
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391


def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    """Return the mscale multiplier for a rotary-embedding scaling factor.

    Yields ``1.0`` when ``scale <= 1``; otherwise
    ``0.1 * mscale * ln(scale) + 1.0``.
    """
    import math

    return 1.0 if scale <= 1 else 0.1 * mscale * math.log(scale) + 1.0


class DeepseekV2Attention(nn.Module):
    """Attention layer that expands the low-rank KV projection into per-head K/V.

    Queries come from ``q_proj``, or from ``q_a_proj -> q_a_layernorm ->
    q_b_proj`` when ``q_lora_rank`` is not None. Keys and values are expanded
    by ``kv_b_proj``; q, k and v heads are zero-padded to
    ``_PADDED_HEAD_DIM`` before being handed to ``RadixAttention``.

    Args:
        config: Model config; ``rms_norm_eps`` is read for the RMSNorm layers.
        hidden_size: Model hidden size.
        num_heads: Total number of heads; must be divisible by the attention
            tensor-parallel size.
        qk_nope_head_dim: Per-head size of the part of q/k that gets no rotary
            embedding.
        qk_rope_head_dim: Per-head size of the rotary part of q/k.
        v_head_dim: Per-head value size.
        q_lora_rank: Rank of the low-rank query projection, or None to use a
            single ``q_proj``.
        kv_lora_rank: Rank of the low-rank KV projection.
        rope_theta: Rotary base.
        rope_scaling: Optional rotary scaling dict. When given, its
            ``rope_type`` entry is overwritten in place and ``factor`` is
            required.
        max_position_embeddings: Maximum position for the rotary embedding.
        quant_config: Optional quantization config forwarded to sub-layers.
        layer_id: Layer index forwarded to ``RadixAttention``.
        reduce_results: Forwarded to the ``o_proj`` row-parallel linear.
        prefix: Name prefix used for the sub-layer prefixes.
    """

    # Head size q, k and v are padded to before attention.
    # TODO, support head_size 192
    _PADDED_HEAD_DIM = 256

    def __init__(
        self,
        config: PretrainedConfig,
        hidden_size: int,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: int,
        kv_lora_rank: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        layer_id=None,
        reduce_results: bool = True,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.layer_id = layer_id
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank

        self.dp_size = get_attention_dp_size()
        attn_tp_rank = get_attention_tp_rank()
        attn_tp_size = get_attention_tp_size()

        self.num_heads = num_heads
        assert num_heads % attn_tp_size == 0
        self.num_local_heads = num_heads // attn_tp_size
        self.scaling = self.qk_head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        if self.q_lora_rank is not None:
            self.q_a_proj = ReplicatedLinear(
                self.hidden_size,
                self.q_lora_rank,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("q_a_proj", prefix),
            )
            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
            # Shard by the attention TP group so the output matches
            # num_local_heads, which forward() relies on when reshaping q.
            self.q_b_proj = ColumnParallelLinear(
                q_lora_rank,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("q_b_proj", prefix),
                tp_rank=attn_tp_rank,
                tp_size=attn_tp_size,
            )
        else:
            self.q_proj = ColumnParallelLinear(
                self.hidden_size,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("q_proj", prefix),
                tp_rank=attn_tp_rank,
                tp_size=attn_tp_size,
            )

        self.kv_a_proj_with_mqa = ReplicatedLinear(
            self.hidden_size,
            self.kv_lora_rank + self.qk_rope_head_dim,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("kv_a_proj_with_mqa", prefix),
        )
        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
        # Same attention-TP sharding as the query projections and o_proj.
        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("kv_b_proj", prefix),
            tp_rank=attn_tp_rank,
            tp_size=attn_tp_size,
        )
        # O projection.
        self.o_proj = RowParallelLinear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("o_proj", prefix),
            reduce_results=reduce_results,
            tp_rank=attn_tp_rank,
            tp_size=attn_tp_size,
        )
        # rope_scaling defaults to None, so only tag it when one was given.
        if rope_scaling:
            rope_scaling["rope_type"] = "deepseek_yarn"
        self.rotary_emb = get_rope_wrapper(
            qk_rope_head_dim,
            rotary_dim=qk_rope_head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=False,
            device=global_server_args_dict["device"],
        )

        if rope_scaling:
            mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
            scaling_factor = rope_scaling["factor"]
            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
            self.scaling = self.scaling * mscale * mscale

        self.attn = RadixAttention(
            self.num_local_heads,
            self._PADDED_HEAD_DIM,
            self.scaling,
            num_kv_heads=self.num_local_heads,
            layer_id=layer_id,
            quant_config=quant_config,
            prefix=add_prefix("attn", prefix),
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Compute attention for ``hidden_states`` and apply ``o_proj``.

        A batch with zero rows is returned unchanged; this is only permitted
        when ``o_proj`` does not reduce its results (asserted below).
        """
        if hidden_states.shape[0] == 0:
            assert (
                not self.o_proj.reduce_results
            ), "short-circuiting allreduce will lead to hangs"
            return hidden_states

        padded_dim = self._PADDED_HEAD_DIM

        if self.q_lora_rank is not None:
            q = self.q_a_proj(hidden_states)[0]
            q = self.q_a_layernorm(q)
            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
        else:
            q = self.q_proj(hidden_states)[0].view(
                -1, self.num_local_heads, self.qk_head_dim
            )
        _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
        kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        latent_cache = latent_cache.unsqueeze(1)
        kv_a = self.kv_a_layernorm(kv_a.contiguous())
        kv = self.kv_b_proj(kv_a)[0]
        kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
        k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
        # The rotary part of the key is the tail of the latent projection.
        k_pe = latent_cache[:, :, self.kv_lora_rank :]
        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
        # Write the rotated slice back into q, then assemble k in q's layout.
        q[..., self.qk_nope_head_dim :] = q_pe
        k = torch.empty_like(q)
        k[..., : self.qk_nope_head_dim] = k_nope
        k[..., self.qk_nope_head_dim :] = k_pe
        # Zero-pad every head to the padded head size, then flatten the heads.
        q = torch.nn.functional.pad(
            q, [0, padded_dim - self.qk_head_dim], value=0
        ).view(-1, self.num_local_heads * padded_dim)
        k = torch.nn.functional.pad(
            k, [0, padded_dim - self.qk_head_dim], value=0
        ).view(-1, self.num_local_heads * padded_dim)
        v = torch.nn.functional.pad(
            v, [0, padded_dim - self.v_head_dim], value=0
        ).view(-1, self.num_local_heads * padded_dim)
        attn_output = self.attn(q, k, v, forward_batch)
        # Drop the padding again before the output projection.
        attn_output = attn_output.view(-1, self.num_local_heads, padded_dim)[
            ..., : self.v_head_dim
        ].reshape(-1, self.num_local_heads * self.v_head_dim)
        output, _ = self.o_proj(attn_output)
        return output


548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
class DeepseekV2AttentionMLA(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        hidden_size: int,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: int,
        kv_lora_rank: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
        layer_id: int = None,
        prefix: str = "",
    ) -> None:
        """Build the projections, rotary embedding and two attention backends.

        Two ``RadixAttention`` instances are created: ``attn_mqa`` with a
        single KV head over ``kv_lora_rank + qk_rope_head_dim``, and
        ``attn_mha`` with one KV head per local query head. When
        ``rope_scaling`` is given, its ``rope_type`` entry is overwritten in
        place and ``self.scaling`` is multiplied by the squared mscale;
        otherwise the rotary embedding's ``forward`` is rebound to
        ``forward_native``.
        """
        super().__init__()
        self.layer_id = layer_id
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.dp_size = get_attention_dp_size()
        attn_tp_rank = get_attention_tp_rank()
        attn_tp_size = get_attention_tp_size()

        self.num_heads = num_heads
        assert num_heads % attn_tp_size == 0
        self.num_local_heads = num_heads // attn_tp_size
        self.scaling = self.qk_head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        # Keyword arguments shared by the linear layers below.
        linear_kwargs = dict(bias=False, quant_config=quant_config)
        attn_tp_kwargs = dict(tp_rank=attn_tp_rank, tp_size=attn_tp_size)

        # For tensor parallel attention
        if self.q_lora_rank is None:
            self.q_proj = ColumnParallelLinear(
                self.hidden_size,
                self.num_heads * self.qk_head_dim,
                prefix=add_prefix("q_proj", prefix),
                **linear_kwargs,
                **attn_tp_kwargs,
            )
        else:
            self.q_a_proj = ReplicatedLinear(
                self.hidden_size,
                self.q_lora_rank,
                prefix=add_prefix("q_a_proj", prefix),
                **linear_kwargs,
            )
            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
            self.q_b_proj = ColumnParallelLinear(
                q_lora_rank,
                self.num_heads * self.qk_head_dim,
                prefix=add_prefix("q_b_proj", prefix),
                **linear_kwargs,
                **attn_tp_kwargs,
            )

        kv_b_out_features = self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)
        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            kv_b_out_features,
            prefix=add_prefix("kv_b_proj", prefix),
            **linear_kwargs,
            **attn_tp_kwargs,
        )
        # O projection.
        self.o_proj = RowParallelLinear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            reduce_results=reduce_results,
            prefix=add_prefix("o_proj", prefix),
            **linear_kwargs,
            **attn_tp_kwargs,
        )

        self.kv_a_proj_with_mqa = ReplicatedLinear(
            self.hidden_size,
            self.kv_lora_rank + self.qk_rope_head_dim,
            prefix=add_prefix("kv_a_proj_with_mqa", prefix),
            **linear_kwargs,
        )
        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)

        if rope_scaling:
            rope_scaling["rope_type"] = "deepseek_yarn"

        self.rotary_emb = get_rope(
            qk_rope_head_dim,
            rotary_dim=qk_rope_head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=False,
        )

        if not rope_scaling:
            self.rotary_emb.forward = self.rotary_emb.forward_native
        else:
            mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
            mscale = yarn_get_mscale(rope_scaling["factor"], float(mscale_all_dim))
            self.scaling = self.scaling * mscale * mscale

        # Single-KV-head attention over the compressed KV plus rotary part.
        self.attn_mqa = RadixAttention(
            self.num_local_heads,
            self.kv_lora_rank + self.qk_rope_head_dim,
            self.scaling,
            num_kv_heads=1,
            layer_id=layer_id,
            v_head_dim=self.kv_lora_rank,
            quant_config=quant_config,
            prefix=add_prefix("attn_mqa", prefix),
        )

        # Attention with one KV head per local query head.
        self.attn_mha = RadixAttention(
            self.num_local_heads,
            self.qk_nope_head_dim + self.qk_rope_head_dim,
            self.scaling,
            num_kv_heads=self.num_local_heads,
            layer_id=layer_id,
            v_head_dim=self.v_head_dim,
            quant_config=quant_config,
            prefix=add_prefix("attn_mha", prefix),
        )

        # Initialized to None here; not assigned anywhere else in this method.
        self.w_kc = None
        self.w_vc = None
        self.w_scale = None

        # Snapshot the server arguments and environment switch used later.
        self.flashinfer_mla_disable_ragged = global_server_args_dict[
            "flashinfer_mla_disable_ragged"
        ]
        self.attention_backend = global_server_args_dict["attention_backend"]
        self.rocm_fused_decode_mla = os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"

    def no_absorb(self, forward_batch: ForwardBatch) -> bool:
        """Decide whether this batch should bypass the weight-absorbed MLA path.

        ``forward`` calls ``forward_normal`` when this returns True and one of
        the ``forward_absorb*`` methods otherwise.

        Args:
            forward_batch: Batch descriptor; only ``forward_mode`` and
                ``extend_prefix_lens_cpu`` are read here.

        Returns:
            True only for an extend batch that is neither a target-verify nor
            a draft-extend pass and whose prefix lengths sum to zero. The
            "fa3" backend never returns True, and "flashinfer" can return True
            only while ``flashinfer_mla_disable_ragged`` is falsy.
        """
        backend = self.attention_backend
        if backend == "fa3":
            # Flash Attention: keep absorbing for all extend/decode batches.
            return False
        if backend == "flashinfer" and self.flashinfer_mla_disable_ragged:
            # Flashinfer MLA only skips absorption when ragged prefill is enabled.
            return False

        # Flashinfer (ragged prefill) and every other backend (Triton, per the
        # original comment) share one rule: normal computation for a prefill
        # without cached prefix, weight absorption for everything else.
        # The terms short-circuit in this order, so extend_prefix_lens_cpu is
        # only read for plain extend batches.
        mode = forward_batch.forward_mode
        return (
            mode.is_extend()
            and not mode.is_target_verify()
            and not mode.is_draft_extend()
            and sum(forward_batch.extend_prefix_lens_cpu) == 0
        )

721
722
723
724
    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Dispatch the attention computation to the matching implementation.

        Args:
            positions: Forwarded unchanged to the selected implementation.
            hidden_states: Input activations; an input whose first dimension
                is 0 is returned as-is without running attention.
            forward_batch: Batch descriptor used to pick the implementation.

        Returns:
            The output of ``forward_normal``, ``forward_absorb`` or
            ``forward_absorb_fused_mla_rope``; or ``hidden_states`` itself
            when it is empty.

        Raises:
            AssertionError: If the input is empty while ``o_proj`` is
                configured to reduce results.
        """
        if hidden_states.shape[0] == 0:
            # Returning early skips o_proj entirely, which is only safe when
            # o_proj would not have taken part in an all-reduce.
            assert (
                not self.o_proj.reduce_results
            ), "short-circuiting allreduce will lead to hangs"
            return hidden_states

        if self.no_absorb(forward_batch):
            return self.forward_normal(positions, hidden_states, forward_batch)

        # The fused decode kernel is opt-in (rocm_fused_decode_mla is set from
        # SGLANG_ROCM_FUSED_DECODE_MLA in __init__) and is only tried when the
        # module-level _is_hip flag is set and the batch is a decode batch.
        if (
            _is_hip
            and self.rocm_fused_decode_mla
            and forward_batch.forward_mode.is_decode()
        ):
            return self.forward_absorb_fused_mla_rope(
                positions, hidden_states, forward_batch
            )
        return self.forward_absorb(positions, hidden_states, forward_batch)
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795

    def forward_normal(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Run attention through ``attn_mha`` without weight absorption.

        ``forward`` selects this path when ``no_absorb`` returns True. Keys and
        values are expanded per head via ``kv_b_proj`` for the attention call,
        while the KV pool still receives the compact latent cache.

        Args:
            positions: Passed to ``rotary_emb`` together with the rope slices.
            hidden_states: Input activations (presumably
                ``(num_tokens, hidden_size)`` -- TODO confirm against caller).
            forward_batch: Supplies ``token_to_kv_pool`` and ``out_cache_loc``
                and is forwarded to ``attn_mha``.

        Returns:
            The first element of the ``o_proj`` result.
        """
        # Query projection: low-rank (a -> layernorm -> b) when q_lora_rank is
        # set, otherwise a single q_proj.
        if self.q_lora_rank is not None:
            q = self.q_a_proj(hidden_states)[0]
            q = self.q_a_layernorm(q)
            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
        else:
            q = self.q_proj(hidden_states)[0].view(
                -1, self.num_local_heads, self.qk_head_dim
            )
        # Only the rope slice is needed separately; q is patched in place below.
        _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
        kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        # Insert a singleton dimension at index 1 before slicing/writing below.
        latent_cache = latent_cache.unsqueeze(1)
        kv_a = self.kv_a_layernorm(kv_a.contiguous())
        # Expand the normalized latent into per-head k_nope and v.
        kv = self.kv_b_proj(kv_a)[0]
        kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
        k_nope = kv[..., : self.qk_nope_head_dim]
        v = kv[..., self.qk_nope_head_dim :]
        k_pe = latent_cache[:, :, self.kv_lora_rank :]
        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
        # Assemble full q and k as [nope part | rotated rope part].
        q[..., self.qk_nope_head_dim :] = q_pe
        k = torch.empty_like(q)
        k[..., : self.qk_nope_head_dim] = k_nope
        k[..., self.qk_nope_head_dim :] = k_pe

        # Rewrite the latent cache as [normalized kv_a | rotated k_pe].
        latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
        latent_cache[:, :, self.kv_lora_rank :] = k_pe

        # Save latent cache
        forward_batch.token_to_kv_pool.set_kv_buffer(
            self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
        )
        # The cache was written explicitly above, so attn_mha must not save again.
        attn_output = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
        attn_output = attn_output.reshape(-1, self.num_local_heads * self.v_head_dim)
        output, _ = self.o_proj(attn_output)
        return output

    def forward_absorb(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Run attention through ``attn_mqa`` with absorbed projection weights.

        The query's non-rope part is multiplied by ``w_kc`` before attention
        and the attention output is multiplied by ``w_vc`` afterwards, so the
        attention call itself works on the compact latent representation.
        ``w_kc``, ``w_vc`` and ``w_scale`` are ``None`` after ``__init__`` and
        are expected to be populated elsewhere before this method runs.

        Args:
            positions: Passed to ``rotary_emb`` together with the rope slices.
            hidden_states: Input activations (presumably
                ``(num_tokens, hidden_size)`` -- TODO confirm against caller).
            forward_batch: Forwarded to ``attn_mqa``.

        Returns:
            The first element of the ``o_proj`` result.
        """
        q_len = hidden_states.shape[0]
        # Attention input: [absorbed nope part (kv_lora_rank) | rope part].
        q_input = hidden_states.new_empty(
            q_len, self.num_local_heads, self.kv_lora_rank + self.qk_rope_head_dim
        )
        # Query projection: low-rank (a -> layernorm -> b) when q_lora_rank is
        # set, otherwise a single q_proj.
        if self.q_lora_rank is not None:
            q = self.q_a_proj(hidden_states)[0]
            q = self.q_a_layernorm(q)
            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
        else:
            q = self.q_proj(hidden_states)[0].view(
                -1, self.num_local_heads, self.qk_head_dim
            )
        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        # q_nope @ w_kc, batched over heads (hence the transposes). The branch
        # depends on the storage dtype of the absorbed weight.
        if self.w_kc.dtype == torch.float8_e4m3fnuz:
            # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
            # Until then, upcast to bfloat16 and apply the weight scale by hand.
            q_nope_out = torch.bmm(
                q_nope.to(torch.bfloat16).transpose(0, 1),
                self.w_kc.to(torch.bfloat16) * self.w_scale,
            )
        elif self.w_kc.dtype == torch.float8_e4m3fn:
            # Quantize the activation as well and use the fp8 batched matmul.
            q_nope_val, q_nope_scale = input_to_float8(
                q_nope.transpose(0, 1), torch.float8_e4m3fn
            )
            q_nope_out = bmm_fp8(
                q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
            )
        else:
            q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
        q_input[..., : self.kv_lora_rank] = q_nope_out.transpose(0, 1)

        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
        v_input = latent_cache[..., : self.kv_lora_rank]
        v_input = self.kv_a_layernorm(v_input.contiguous()).unsqueeze(1)
        # k_input is derived from latent_cache via unsqueeze; its leading
        # kv_lora_rank entries are overwritten with the normalized values.
        k_input = latent_cache.unsqueeze(1)
        k_input[..., : self.kv_lora_rank] = v_input
        k_pe = k_input[..., self.kv_lora_rank :]

        # Rotate the rope slices, then store them in the attention inputs.
        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
        q_input[..., self.kv_lora_rank :] = q_pe
        k_input[..., self.kv_lora_rank :] = k_pe

        attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
        attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)

        # attn_output @ w_vc, with the same dtype dispatch as for w_kc above.
        if self.w_vc.dtype == torch.float8_e4m3fnuz:
            # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
            attn_bmm_output = torch.bmm(
                attn_output.to(torch.bfloat16).transpose(0, 1),
                self.w_vc.to(torch.bfloat16) * self.w_scale,
            )
        elif self.w_vc.dtype == torch.float8_e4m3fn:
            attn_output_val, attn_output_scale = input_to_float8(
                attn_output.transpose(0, 1), torch.float8_e4m3fn
            )
            attn_bmm_output = bmm_fp8(
                attn_output_val,
                self.w_vc,
                attn_output_scale,
                self.w_scale,
                torch.bfloat16,
            )
        else:
            attn_bmm_output = torch.bmm(attn_output.transpose(0, 1), self.w_vc)
        # Back to token-major order with the head and feature dims merged.
        attn_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
        output, _ = self.o_proj(attn_output)

        return output

    def forward_absorb_fused_mla_rope(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Weight-absorbed attention using ``decode_attention_fwd_grouped_rope``.

        Same overall structure as ``forward_absorb``, but the attention step
        calls ``decode_attention_fwd_grouped_rope`` directly instead of
        ``attn_mqa``. ``forward`` only selects this path for decode batches
        when ``_is_hip`` and ``rocm_fused_decode_mla`` are both set.

        Rotary embedding is applied inside the kernel unless the environment
        variable ``SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION`` is set to something
        other than "1", in which case ``rotary_emb`` is applied beforehand.

        Args:
            positions: Passed to ``rotary_emb`` or to the kernel.
            hidden_states: Input activations (presumably
                ``(num_tokens, hidden_size)`` -- TODO confirm against caller).
            forward_batch: Supplies the attention backend metadata, the KV
                pool, ``out_cache_loc`` and ``batch_size``.

        Returns:
            The first element of the ``o_proj`` result.
        """
        # Read on every call; defaults to enabled.
        enable_rope_fusion = (
            os.getenv("SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION", "1") == "1"
        )
        q_len = hidden_states.shape[0]
        # Attention input: [absorbed nope part (kv_lora_rank) | rope part].
        q_input = hidden_states.new_empty(
            q_len, self.num_local_heads, self.kv_lora_rank + self.qk_rope_head_dim
        )
        if self.q_lora_rank is not None:
            q = self.q_a_proj(hidden_states)[0]
            q = self.q_a_layernorm(q)
            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
        else:
            q = self.q_proj(hidden_states)[0].view(
                -1, self.num_local_heads, self.qk_head_dim
            )
        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        # q_nope @ w_kc, batched over heads; branch on the weight's dtype.
        if self.w_kc.dtype == torch.float8_e4m3fnuz:
            # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
            q_nope_out = torch.bmm(
                q_nope.to(torch.bfloat16).transpose(0, 1),
                self.w_kc.to(torch.bfloat16) * self.w_scale,
            )
        elif self.w_kc.dtype == torch.float8_e4m3fn:
            q_nope_val, q_nope_scale = input_to_float8(
                q_nope.transpose(0, 1), torch.float8_e4m3fn
            )
            q_nope_out = bmm_fp8(
                q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
            )
        else:
            q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
        q_input[..., : self.kv_lora_rank] = q_nope_out.transpose(0, 1)

        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
        v_input = latent_cache[..., : self.kv_lora_rank]
        v_input = self.kv_a_layernorm(v_input.contiguous()).unsqueeze(1)
        k_input = latent_cache.unsqueeze(1)
        k_input[..., : self.kv_lora_rank] = v_input

        if not enable_rope_fusion:
            # Apply rope here; the kernel then gets no k_pe output buffer.
            k_pe = k_input[..., self.kv_lora_rank :]
            q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
            k_input[..., self.kv_lora_rank :] = k_pe
            k_pe_output = None
        else:
            # Buffer handed to the kernel; copied back into k_input afterwards.
            k_pe_output = torch.empty_like(k_input[..., self.kv_lora_rank :])

        # q_pe is the rotated slice when fusion is off, the raw slice otherwise.
        q_input[..., self.kv_lora_rank :] = q_pe

        # The kernel below replaces the attn_mqa call of forward_absorb;
        # whether it applies rope itself follows enable_rope_fusion.
        attn_output = torch.empty(
            (q_len, self.num_local_heads, self.kv_lora_rank),
            dtype=q.dtype,
            device=q.device,
        )
        # Only three of the seven metadata entries are used here.
        attn_logits, _, kv_indptr, kv_indices, _, _, _ = (
            forward_batch.attn_backend.forward_metadata
        )
        cos_sin_cache = self.rotary_emb.cos_sin_cache
        num_kv_split = forward_batch.attn_backend.num_kv_splits
        sm_scale = self.attn_mqa.scaling
        if attn_logits is None:
            # The backend supplied no scratch buffer; allocate one locally.
            attn_logits = torch.empty(
                (
                    forward_batch.batch_size,
                    self.num_local_heads,
                    num_kv_split,
                    self.kv_lora_rank + 1,
                ),
                dtype=torch.float32,
                device=q.device,
            )

        # save current latent cache.
        forward_batch.token_to_kv_pool.set_kv_buffer(
            self.attn_mqa, forward_batch.out_cache_loc, k_input, None
        )
        key_cache_buf = forward_batch.token_to_kv_pool.get_key_buffer(
            self.attn_mqa.layer_id
        )
        # Values are the leading kv_lora_rank entries of the key buffer.
        val_cache_buf = key_cache_buf[..., : self.kv_lora_rank]

        decode_attention_fwd_grouped_rope(
            q_input,
            key_cache_buf,
            val_cache_buf,
            attn_output,
            kv_indptr,
            kv_indices,
            k_pe_output,
            self.kv_lora_rank,
            self.rotary_emb.rotary_dim,
            cos_sin_cache,
            positions,
            attn_logits,
            num_kv_split,
            sm_scale,
            logit_cap=self.attn_mqa.logit_cap,
            use_rope=enable_rope_fusion,
            is_neox_style=self.rotary_emb.is_neox_style,
        )

        if enable_rope_fusion:
            # Store the kernel's k_pe output and save the cache entry again.
            k_input[..., self.kv_lora_rank :] = k_pe_output
            forward_batch.token_to_kv_pool.set_kv_buffer(
                self.attn_mqa, forward_batch.out_cache_loc, k_input, None
            )

        attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)

        # attn_output @ w_vc, with the same dtype dispatch as for w_kc above.
        if self.w_vc.dtype == torch.float8_e4m3fnuz:
            # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
            attn_bmm_output = torch.bmm(
                attn_output.to(torch.bfloat16).transpose(0, 1),
                self.w_vc.to(torch.bfloat16) * self.w_scale,
            )
        elif self.w_vc.dtype == torch.float8_e4m3fn:
            attn_output_val, attn_output_scale = input_to_float8(
                attn_output.transpose(0, 1), torch.float8_e4m3fn
            )
            attn_bmm_output = bmm_fp8(
                attn_output_val,
                self.w_vc,
                attn_output_scale,
                self.w_scale,
                torch.bfloat16,
            )
        else:
            attn_bmm_output = torch.bmm(attn_output.transpose(0, 1), self.w_vc)
        attn_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
        output, _ = self.o_proj(attn_output)

        return output


Liangsheng Yin's avatar
Liangsheng Yin committed
1010
1011
1012
1013
1014
1015
1016
class DeepseekV2DecoderLayer(nn.Module):
    """One decoder layer: self-attention followed by a dense MLP or an MoE block.

    The attention implementation is chosen by the ``disable_mla`` server
    argument. The feed-forward block is ``DeepseekV2MoE`` for next-n layers
    and for layers classified as sparse by the config, ``DeepseekV2MLP``
    otherwise.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        layer_id: int,
        quant_config: Optional[QuantizationConfig] = None,
        is_nextn: bool = False,
        prefix: str = "",
    ) -> None:
        """Build the attention, feed-forward and normalization sub-modules.

        Args:
            config: Model configuration.
            layer_id: Index of this layer.
            quant_config: Forwarded to the attention and feed-forward modules.
            is_nextn: When True the layer always uses the MoE block.
            prefix: Parameter-name prefix forwarded through ``add_prefix``.
        """

        def is_sparse_layer(idx: int):
            # Routed experts must be configured, the first
            # first_k_dense_replace layers stay dense, and after that only
            # every moe_layer_freq-th layer is sparse.
            return (
                config.n_routed_experts is not None
                and idx >= config.first_k_dense_replace
                and idx % config.moe_layer_freq == 0
            )

        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
        self.layer_id = layer_id
        self.dp_size = get_attention_dp_size()
        self.attn_tp_size = get_attention_tp_size()
        self.attn_tp_rank = get_attention_tp_rank()

        # Both attention variants take the same constructor arguments, so only
        # the class differs.
        attn_cls = (
            DeepseekV2Attention
            if global_server_args_dict["disable_mla"]
            else DeepseekV2AttentionMLA
        )
        self.self_attn = attn_cls(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            qk_nope_head_dim=config.qk_nope_head_dim,
            qk_rope_head_dim=config.qk_rope_head_dim,
            v_head_dim=config.v_head_dim,
            q_lora_rank=getattr(config, "q_lora_rank", None),
            kv_lora_rank=config.kv_lora_rank,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            layer_id=layer_id,
            # The layer performs the cross-rank reduction itself in forward_*.
            reduce_results=False,
            prefix=add_prefix("self_attn", prefix),
        )

        if is_nextn or is_sparse_layer(layer_id):
            self.mlp = DeepseekV2MoE(
                config=config,
                quant_config=quant_config,
                prefix=add_prefix("mlp", prefix),
            )
            self.is_sparse = True
        else:
            self.mlp = DeepseekV2MLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=add_prefix("mlp", prefix),
            )
            self.is_sparse = False

        # Truthy when the previous layer is sparse and DeepEP MoE is enabled.
        self.input_is_scattered = (
            is_sparse_layer(layer_id - 1)
            and global_server_args_dict["enable_deepep_moe"]
        )
        self.is_last_layer = self.layer_id == config.num_hidden_layers - 1

        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run the layer via ``forward_deepep`` or ``forward_normal``.

        ``forward_deepep`` is used only when the ``enable_deepep_moe`` server
        argument is set and this layer uses the MoE block.

        Returns:
            ``(hidden_states, residual)`` as produced by the selected method.
        """
        if global_server_args_dict["enable_deepep_moe"] and self.is_sparse:
            return self.forward_deepep(
                positions, hidden_states, forward_batch, residual
            )
        else:
            return self.forward_normal(
                positions, hidden_states, forward_batch, residual
            )

    def forward_normal(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Layer forward pass for the non-DeepEP configuration.

        Args:
            positions: Forwarded to the attention module.
            hidden_states: Layer input.
            forward_batch: Batch descriptor; its ``gathered_buffer`` is used as
                the gather/scatter workspace when ``dp_size != 1``.
            residual: Residual stream from the previous layer, or ``None`` for
                the first layer.

        Returns:
            ``(hidden_states, residual)`` for the next layer.

        Raises:
            AssertionError: If the input is non-empty, ``attn_tp_size != 1``
                and ``input_is_scattered`` is truthy.
        """

        if hidden_states.shape[0] == 0:
            # Empty input: nothing to normalize or attend over.
            residual = hidden_states
        else:
            if residual is None:
                residual = hidden_states
                hidden_states = self.input_layernorm(hidden_states)
            else:
                hidden_states, residual = self.input_layernorm(hidden_states, residual)

            assert not (
                self.attn_tp_size != 1 and self.input_is_scattered
            ), "moe_layer_freq > 1 is not supported when attn_tp_size > 1"

            # Self Attention
            hidden_states = self.self_attn(
                positions=positions,
                hidden_states=hidden_states,
                forward_batch=forward_batch,
            )

        # Gather
        if get_tensor_model_parallel_world_size() > 1:
            # all gather and all reduce
            if self.dp_size != 1:
                # Add the residual on a single attention-TP rank only, before
                # the partial gather.
                if self.attn_tp_rank == 0:
                    hidden_states += residual
                hidden_states, local_hidden_states = (
                    forward_batch.gathered_buffer,
                    hidden_states,
                )
                dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
                dp_scatter(residual, hidden_states, forward_batch)
                hidden_states = self.post_attention_layernorm(hidden_states)
            else:
                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
                hidden_states, residual = self.post_attention_layernorm(
                    hidden_states, residual
                )
        else:
            hidden_states, residual = self.post_attention_layernorm(
                hidden_states, residual
            )

        # Fully Connected
        hidden_states = self.mlp(hidden_states)

        # TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
        # Scatter
        if self.dp_size != 1:
            # important: forward batch.gathered_buffer is used both after scatter and after gather.
            # be careful about this!
            hidden_states, global_hidden_states = (
                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
                hidden_states,
            )
            dp_scatter(hidden_states, global_hidden_states, forward_batch)

        return hidden_states, residual

    def forward_deepep(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Layer forward pass used when DeepEP MoE is enabled for a sparse layer.

        Args:
            positions: Forwarded to the attention module.
            hidden_states: Layer input.
            forward_batch: Batch descriptor; ``gathered_buffer`` is the
                all-gather target and ``forward_mode`` is passed to the MoE.
            residual: Residual stream from the previous layer, or ``None``.

        Returns:
            ``(hidden_states, residual)``. ``residual`` is ``None`` when this
            is the last layer and ``attn_tp_size != 1``, because the residual
            has then already been added into ``hidden_states``.
        """

        if hidden_states.shape[0] == 0:
            residual = hidden_states
        else:
            if residual is None:
                residual = hidden_states
                hidden_states = self.input_layernorm(hidden_states)
            else:
                hidden_states, residual = self.input_layernorm(hidden_states, residual)

        if self.attn_tp_size != 1 and self.input_is_scattered:
            # Collect the per-rank pieces into a slice of the shared buffer
            # before attention.
            hidden_states, local_hidden_states = (
                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
                hidden_states,
            )
            tp_all_gather(
                list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
            )

        # Self Attention
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            forward_batch=forward_batch,
        )

        if self.attn_tp_size != 1:
            if self.input_is_scattered:
                # Reduce-scatter the attention output; this rank keeps its own
                # split.
                tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
                hidden_states = tensor_list[self.attn_tp_rank]
                tp_reduce_scatter(hidden_states, tensor_list)
                if hidden_states.shape[0] != 0:
                    hidden_states, residual = self.post_attention_layernorm(
                        hidden_states, residual
                    )
            else:
                # Fold the residual in on one rank before the reduce-scatter;
                # the scattered result becomes the new residual.
                if self.attn_tp_rank == 0:
                    hidden_states += residual
                tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
                hidden_states = tensor_list[self.attn_tp_rank]
                tp_reduce_scatter(hidden_states, tensor_list)
                residual = hidden_states
                if hidden_states.shape[0] != 0:
                    hidden_states = self.post_attention_layernorm(hidden_states)
        else:
            if hidden_states.shape[0] != 0:
                hidden_states, residual = self.post_attention_layernorm(
                    hidden_states, residual
                )
        hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)

        if self.is_last_layer and self.attn_tp_size != 1:
            # Last layer: add the residual now and all-gather the result.
            hidden_states += residual
            residual = None
            hidden_states, local_hidden_states = (
                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
                hidden_states,
            )
            tp_all_gather(
                list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
            )

        return hidden_states, residual

Liangsheng Yin's avatar
Liangsheng Yin committed
1262
1263
1264
1265
1266
1267
1268
1269

class DeepseekV2Model(nn.Module):
    """Token embedding, the stack of decoder layers, and the final RMSNorm."""

    fall_back_to_pt_during_load = False

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Create the embedding, ``config.num_hidden_layers`` layers and the norm.

        Args:
            config: Model configuration.
            quant_config: Forwarded to every decoder layer.
            prefix: Parameter-name prefix forwarded through ``add_prefix``.
        """
        super().__init__()
        self.padding_id = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            # Tensor parallelism for the embedding is turned off whenever
            # DP attention is enabled.
            enable_tp=not global_server_args_dict["enable_dp_attention"],
        )
        self.layers = nn.ModuleList(
            [
                DeepseekV2DecoderLayer(
                    config,
                    layer_id,
                    quant_config=quant_config,
                    prefix=add_prefix(f"layers.{layer_id}", prefix),
                )
                for layer_id in range(config.num_hidden_layers)
            ]
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.dp_size = get_attention_dp_size()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Embed the input, run every decoder layer, and apply the final norm.

        Args:
            input_ids: Token ids; embedded only when ``input_embeds`` is None.
            positions: Forwarded to every layer.
            forward_batch: Batch descriptor forwarded to every layer.
            input_embeds: Optional precomputed embeddings that replace the
                embedding lookup.

        Returns:
            The final hidden states. The norm is skipped when the batch's
            forward mode is idle.
        """

        if input_embeds is None:
            hidden_states = self.embed_tokens(input_ids)
        else:
            hidden_states = input_embeds

        residual = None
        for i, layer in enumerate(self.layers):
            # Tell the recorder which layer is about to run.
            expert_distribution_recorder.set_current_layer(i)
            hidden_states, residual = layer(
                positions, hidden_states, forward_batch, residual
            )
        if not forward_batch.forward_mode.is_idle():
            # A layer may return residual=None once it has already folded the
            # residual into hidden_states; then the norm takes a single input.
            if residual is None:
                hidden_states = self.norm(hidden_states)
            else:
                hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class DeepseekV2ForCausalLM(nn.Module):
    """Causal-LM wrapper around ``DeepseekV2Model``.

    Combines the decoder stack (``self.model``), a ``ParallelLMHead`` and a
    ``LogitsProcessor``, and implements checkpoint loading including the
    optional "shared experts fusion" rewrite of the checkpoint weights.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the model and resolve the shared-experts-fusion setting.

        Args:
            config: Model configuration; ``architectures``,
                ``n_routed_experts`` and ``routed_scaling_factor`` gate the
                shared-experts-fusion optimization.
            quant_config: Optional quantization configuration passed to the
                sub-modules and kept for ``post_load_weights``.
            prefix: Parameter-name prefix passed to ``add_prefix`` for the
                sub-modules.

        Side effects:
            May overwrite ``global_server_args_dict["n_share_experts_fusion"]``.
        """
        super().__init__()
        self.config = config
        self.tp_size = get_tensor_model_parallel_world_size()
        self.quant_config = quant_config
        self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
        # Only Deepseek V3/R1 can use shared experts fusion optimization now.
        if (
            global_server_args_dict.get("disable_shared_experts_fusion", False)
            or self.config.architectures[0] != "DeepseekV3ForCausalLM"
            or self.config.n_routed_experts != 256
            or self.config.routed_scaling_factor != 2.5
        ):
            self.n_share_experts_fusion = None
            global_server_args_dict["n_share_experts_fusion"] = None
            logger.info(
                "Only Deepseek V3/R1 can use shared experts fusion optimization. Shared experts fusion optimization is disabled."
            )
        elif self.n_share_experts_fusion is None:
            # Eligible model and no explicit value: default to the TP size.
            global_server_args_dict["n_share_experts_fusion"] = self.tp_size
            self.n_share_experts_fusion = self.tp_size
            logger.info(
                f"Shared experts fusion optimization is default enabled in DeepSeek V3/R1, and n_share_experts_fusion is set to {self.tp_size}. You can tune it by setting --n_share_experts_fusion or disable it by setting --disable_shared_experts_fusion."
            )

        self.model = DeepseekV2Model(
            config, quant_config, prefix=add_prefix("model", prefix)
        )
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=add_prefix("lm_head", prefix),
        )
        self.logits_processor = LogitsProcessor(config)
        self.dp_size = get_attention_dp_size()

    def get_input_embeddings(self) -> nn.Embedding:
        """Return the token-embedding module of the wrapped model."""
        return self.model.embed_tokens

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the decoder stack and hand the result to the logits processor.

        Args:
            input_ids: Token ids; also passed to the logits processor.
            positions: Forwarded to ``self.model``.
            forward_batch: Batch metadata forwarded to ``self.model`` and the
                logits processor.
            input_embeds: Optional precomputed embeddings forwarded to
                ``self.model``.

        Returns:
            Whatever ``self.logits_processor`` returns for this batch.
        """
        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)

        return self.logits_processor(
            input_ids, hidden_states, self.lm_head, forward_batch
        )

    def post_load_weights(self):
        """Derive per-layer ``w_kc`` / ``w_vc`` / ``w_scale`` from ``kv_b_proj``.

        Does nothing when ``global_server_args_dict["disable_mla"]`` is set.
        Otherwise, for every decoder layer, the ``kv_b_proj`` weight is
        dequantized or requantized as required by its storage format, split
        per head into a key part and a value part, and stored on the layer's
        attention module.
        """

        # Perform post-processing after loading weights

        if not global_server_args_dict["disable_mla"]:
            for layer_id in range(self.config.num_hidden_layers):
                self_attn = self.model.layers[layer_id].self_attn
                if hasattr(self_attn.kv_b_proj, "qweight"):
                    # AWQ compatible
                    if _is_cuda:
                        w = awq_dequantize(
                            self_attn.kv_b_proj.qweight,
                            self_attn.kv_b_proj.scales,
                            self_attn.kv_b_proj.qzeros,
                        ).T
                    else:
                        w = ops.awq_dequantize(
                            self_attn.kv_b_proj.qweight,
                            self_attn.kv_b_proj.scales,
                            self_attn.kv_b_proj.qzeros,
                            0,
                            0,
                            0,
                        ).T
                else:
                    w = self_attn.kv_b_proj.weight
                # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
                # This may affect the accuracy of fp8 model.
                if hasattr(self.quant_config, "weight_block_size") and w.dtype in (
                    torch.float8_e4m3fn,
                    torch.float8_e4m3fnuz,
                ):
                    weight_block_size = self.quant_config.weight_block_size
                    if weight_block_size is not None:
                        assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
                        if _is_hip:
                            weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                                weight=w,
                                weight_scale=self_attn.kv_b_proj.weight_scale_inv,
                                input_scale=None,
                            )
                        else:
                            weight = w
                            weight_scale = self_attn.kv_b_proj.weight_scale_inv

                        w, scale = block_quant_to_tensor_quant(
                            weight, weight_scale, weight_block_size
                        )
                        self_attn.w_scale = scale
                if w.dtype == torch.int8:
                    if hasattr(self.quant_config, "weight_block_size"):
                        # block-wise int8 need it
                        weight_block_size = self.quant_config.weight_block_size
                        if weight_block_size is not None:
                            assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
                            weight = w
                            weight_scale = self_attn.kv_b_proj.weight_scale_inv
                            w = int8_block_dequant(
                                weight, weight_scale, weight_block_size
                            ).to(torch.bfloat16)
                    else:
                        # channel-wise int8 need it
                        w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
                            torch.bfloat16
                        )
                # Group rows per head, then split each head into its key
                # (qk_nope_head_dim rows) and value (v_head_dim rows) parts.
                w_kc, w_vc = w.unflatten(
                    0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
                ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
                self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
                self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
                # Fall back to the projection's own scale when no scale was
                # set by the requantization branch above.
                if (
                    hasattr(self_attn.kv_b_proj, "weight_scale")
                    and self_attn.w_scale is None
                ):
                    self_attn.w_scale = self_attn.kv_b_proj.weight_scale
                    if _is_hip:
                        self_attn.w_scale *= 2.0

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors into the model parameters.

        When shared-experts fusion is active, each MoE layer's shared-expert
        tensors are first cloned ``n_share_experts_fusion`` times under
        routed-expert names (indices starting at ``n_routed_experts``) and the
        original shared-expert entries are dropped. Each remaining tensor is
        then routed to its parameter through the stacked ``gate_up_proj``
        mapping, the expert mapping, or a plain load, and
        ``post_load_weights`` runs at the end.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        if self.n_share_experts_fusion is not None and self.n_share_experts_fusion > 0:
            weights_list = list(weights)
            weights_dict = dict(weights_list)
            suffix_list = [
                "down_proj.weight",
                "down_proj.weight_scale_inv",
                "gate_proj.weight",
                "gate_proj.weight_scale_inv",
                "up_proj.weight",
                "up_proj.weight_scale_inv",
            ]
            # A set keeps the final filter O(1) per checkpoint entry and
            # avoids storing the same name once per replica.
            names_to_remove = set()
            for moe_layer in tqdm(
                range(
                    self.config.first_k_dense_replace,
                    self.config.num_hidden_layers,
                    self.config.moe_layer_freq,
                ),
                desc=f"Cloning {self.n_share_experts_fusion} "
                "replicas of the shared expert into MoE",
            ):
                for num_repeat in range(self.n_share_experts_fusion):
                    for suffix in suffix_list:
                        shared_expert_weight_name = (
                            f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
                        )
                        weights_list.append(
                            (
                                f"model.layers.{moe_layer}."
                                f"mlp.experts."
                                f"{self.config.n_routed_experts + num_repeat}"
                                f".{suffix}",
                                weights_dict[shared_expert_weight_name].clone(),
                            )
                        )
                        names_to_remove.add(shared_expert_weight_name)
            weights = [w for w in weights_list if w[0] not in names_to_remove]

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        MoEImpl = (
            DeepEPMoE
            if global_server_args_dict["enable_deepep_moe"]
            else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
        )
        expert_params_mapping = MoEImpl.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.n_routed_experts
            + (
                self.n_share_experts_fusion
                if self.n_share_experts_fusion is not None
                else 0
            ),
        )

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            # TODO(HandH1998): Modify it when nextn is supported.
            if hasattr(self.config, "num_nextn_predict_layers"):
                num_nextn_layers = self.config.num_nextn_predict_layers
                if num_nextn_layers > 0 and name.startswith("model.layers"):
                    name_list = name.split(".")
                    # Skip layers whose index is beyond the regular stack.
                    if (
                        len(name_list) >= 3
                        and int(name_list[2]) >= self.config.num_hidden_layers
                    ):
                        continue
            if "rotary_emb.inv_freq" in name:
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if ("mlp.experts." in name) and name not in params_dict:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(
                        param,
                        loaded_weight,
                        name,
                        shard_id=shard_id,
                        expert_id=expert_id,
                    )
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue

                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)

        self.post_load_weights()

    def get_embed_and_head(self):
        """Return the ``(embed_tokens.weight, lm_head.weight)`` pair."""
        return self.model.embed_tokens.weight, self.lm_head.weight

    def set_embed_and_head(self, embed, head):
        """Replace the embedding and LM-head weights with the given tensors.

        The current weights are deleted before the new ones are assigned,
        after which the CUDA cache is emptied and the device synchronized.
        """
        del self.model.embed_tokens.weight
        del self.lm_head.weight
        self.model.embed_tokens.weight = embed
        self.lm_head.weight = head
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

Liangsheng Yin's avatar
Liangsheng Yin committed
1601

HandH1998's avatar
HandH1998 committed
1602
1603
1604
1605
1606
class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
    """DeepSeek-V3 entry point; reuses ``DeepseekV2ForCausalLM`` unchanged."""


# Model classes exported by this module.
# NOTE(review): presumably read by the model loader/registry to map
# architecture names to implementations — confirm against the loader.
EntryClass = [DeepseekV2ForCausalLM, DeepseekV3ForCausalLM]