# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# Adapted from:
# https://github.com/vllm-project/vllm/blob/fb6af8bc086328ca6659e72d11ffd4309ce4de22/vllm/model_executor/models/deepseek_v2.py
"""Inference-only DeepseekV2 model."""
18

19
import logging
20
import os
21
from enum import IntEnum, auto
Liangsheng Yin's avatar
Liangsheng Yin committed
22
23
24
from typing import Any, Dict, Iterable, Optional, Tuple

import torch
Ke Bao's avatar
Ke Bao committed
25
import torch.nn.functional as F
Liangsheng Yin's avatar
Liangsheng Yin committed
26
from torch import nn
27
from tqdm import tqdm
Liangsheng Yin's avatar
Liangsheng Yin committed
28
from transformers import PretrainedConfig
29
30

from sglang.srt.distributed import (
Liangsheng Yin's avatar
Liangsheng Yin committed
31
    get_tensor_model_parallel_world_size,
32
    parallel_state,
Liangsheng Yin's avatar
Liangsheng Yin committed
33
34
    tensor_model_parallel_all_reduce,
)
35
from sglang.srt.layers.activation import SiluAndMul
36
37
38
39
40
from sglang.srt.layers.communicator import (
    LayerCommunicator,
    LayerScatterModes,
    enable_moe_dense_fully_dp,
)
Lianmin Zheng's avatar
Lianmin Zheng committed
41
42
43
from sglang.srt.layers.dp_attention import (
    get_attention_tp_rank,
    get_attention_tp_size,
44
    get_local_attention_dp_size,
Lianmin Zheng's avatar
Lianmin Zheng committed
45
)
46
from sglang.srt.layers.layernorm import RMSNorm
47
48
49
50
51
52
from sglang.srt.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
Liangsheng Yin's avatar
Liangsheng Yin committed
53
from sglang.srt.layers.logits_processor import LogitsProcessor
fzyzcjy's avatar
fzyzcjy committed
54
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
Lianmin Zheng's avatar
Lianmin Zheng committed
55
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
56
from sglang.srt.layers.moe.topk import select_experts
57
from sglang.srt.layers.quantization import deep_gemm_wrapper
58
from sglang.srt.layers.quantization.base_config import QuantizationConfig
59
from sglang.srt.layers.quantization.fp8_kernel import (
60
    is_fp8_fnuz,
61
    per_tensor_quant_mla_fp8,
62
    per_token_group_quant_mla_deep_gemm_masked_fp8,
63
)
HandH1998's avatar
HandH1998 committed
64
from sglang.srt.layers.quantization.fp8_utils import (
65
    block_quant_dequant,
HandH1998's avatar
HandH1998 committed
66
    block_quant_to_tensor_quant,
67
    channel_quant_to_tensor_quant,
68
    normalize_e4m3fn_to_e4m3fnuz,
69
    requant_weight_ue8m0_inplace,
HandH1998's avatar
HandH1998 committed
70
)
71
72
73
from sglang.srt.layers.quantization.int8_utils import (
    block_dequant as int8_block_dequant,
)
Liangsheng Yin's avatar
Liangsheng Yin committed
74
from sglang.srt.layers.radix_attention import RadixAttention
75
from sglang.srt.layers.rotary_embedding import get_rope
76
77
78
79
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
80
81
82
83
from sglang.srt.managers.expert_distribution import (
    get_global_expert_distribution_recorder,
)
from sglang.srt.managers.expert_location import ModelConfigForExpertLocation
84
from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
85
from sglang.srt.managers.schedule_batch import global_server_args_dict
86
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
87
from sglang.srt.model_loader.weight_utils import default_weight_loader
88
89
90
91
from sglang.srt.two_batch_overlap import (
    MaybeTboDeepEPDispatcher,
    model_forward_maybe_tbo,
)
92
93
94
from sglang.srt.utils import (
    BumpAllocator,
    DeepEPMode,
95
    LazyValue,
96
    add_prefix,
97
    bind_or_assign,
98
99
100
101
    get_bool_env_var,
    get_int_env_var,
    is_cuda,
    is_hip,
102
    is_non_idle_and_non_empty,
103
    log_info_on_rank0,
104
)
105

106
# Platform / backend feature flags, resolved once at import time.
_is_hip = is_hip()
_is_cuda = is_cuda()
_is_fp8_fnuz = is_fp8_fnuz()
# SGLANG_USE_AITER only takes effect on HIP builds.
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

# awq_dequantize comes from vLLM's custom ops off CUDA and from sgl_kernel on
# CUDA; bmm_fp8 and merge_state_v2 are only imported on the CUDA path.
if not _is_cuda:
    from vllm._custom_ops import awq_dequantize
else:
    from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2

if _is_hip:
    # ROCm-only decode-attention kernel.
    from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
        decode_attention_fwd_grouped_rope,
    )

if _use_aiter:
    # Replaces the sglang ``get_rope`` imported above when AITER is enabled.
    from aiter.rotary_embedding import get_rope

logger = logging.getLogger(__name__)


class AttnForwardMethod(IntEnum):
    """Attention implementation variants an MLA layer can pick for a batch.

    Values are the integers ``auto()`` would assign (starting at 1).
    """

    # Plain multi-head attention.
    MHA = 1

    # Multi-latent attention with weight absorption.
    MLA = 2

    # Multi-head attention over a KV cache processed in chunks; this avoids
    # OOM when prefix lengths are long.
    MHA_CHUNKED_KV = 3

    # MLA with RoPE fused in.
    MLA_FUSED_ROPE = 4


class DeepseekV2MLP(nn.Module):
    """Dense gated-SiLU feed-forward block: ``down_proj(SiluAndMul(gate_up_proj(x)))``.

    ``gate_up_proj`` is a merged column-parallel projection producing the gate
    and up halves; ``down_proj`` is row-parallel. ``tp_rank``/``tp_size`` are
    forwarded to both linear layers (``None`` leaves their defaults in place).
    Only ``hidden_act == "silu"`` is accepted.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
        prefix: str = "",
        tp_rank: Optional[int] = None,
        tp_size: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.tp_size = tp_size

        # Both projections share the same tensor-parallel placement override.
        tp_overrides = {"tp_rank": tp_rank, "tp_size": tp_size}

        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("gate_up_proj", prefix),
            **tp_overrides,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=add_prefix("down_proj", prefix),
            **tp_overrides,
        )

        if hidden_act == "silu":
            self.act_fn = SiluAndMul()
        else:
            raise ValueError(
                f"Unsupported activation: {hidden_act}. "
                "Only silu is supported for now."
            )

    def forward(self, x, forward_batch=None):
        """Apply the MLP to ``x``; ``forward_batch`` is accepted but unused."""
        # Short-circuit an empty batch only when TP was explicitly set to 1;
        # presumably with tp_size > 1 (or unset) every rank must still run the
        # projections so any collective inside them stays in step.
        nothing_to_do = self.tp_size == 1 and x.shape[0] == 0
        if nothing_to_do:
            return x

        gate_and_up, _ = self.gate_up_proj(x)
        activated = self.act_fn(gate_and_up)
        output, _ = self.down_proj(activated)
        return output


class MoEGate(nn.Module):
    """Router for the routed experts: one logit per expert for every token.

    Registers ``weight`` with shape ``(n_routed_experts, hidden_size)``. When
    ``config.topk_method == "noaux_tc"`` it also registers
    ``e_score_correction_bias`` with shape ``(n_routed_experts,)``; otherwise
    that attribute is ``None``. Both parameters are allocated uninitialised
    and are expected to be filled by weight loading. ``prefix`` is accepted
    for interface parity and is unused.
    """

    def __init__(
        self,
        config,
        prefix: str = "",
    ):
        super().__init__()
        expert_count = config.n_routed_experts
        self.weight = nn.Parameter(torch.empty((expert_count, config.hidden_size)))

        wants_correction_bias = config.topk_method == "noaux_tc"
        self.e_score_correction_bias = (
            nn.Parameter(torch.empty((expert_count)))
            if wants_correction_bias
            else None
        )

    def forward(self, hidden_states):
        """Return router logits ``hidden_states @ weight.T`` (no bias term)."""
        return F.linear(hidden_states, self.weight, None)


class DeepseekV2MoE(nn.Module):
    """Mixture-of-experts FFN block of a DeepSeek-V2/V3 decoder layer.

    Tokens are scored by ``self.gate`` (router logits) and processed by
    ``self.experts`` (the class returned by ``get_moe_impl_class()``). The
    routed output is scaled by ``config.routed_scaling_factor`` and summed
    with the output of the always-active shared experts, if any.

    Shared experts are either fused into ``self.experts``
    (``num_fused_shared_experts`` extra experts and top-k slots) or run as a
    separate ``DeepseekV2MLP`` in ``self.shared_experts``; a config without
    shared experts (``n_shared_experts is None``) uses neither.

    Two execution paths exist:
      * ``forward_normal``: local expert compute followed by a TP all-reduce
        when ``tp_size > 1``;
      * ``forward_deepep`` and the stage-wise ``op_*`` methods (used with the
        two-batch-overlap dispatcher), selected when ``enable_deepep_moe`` is
        set in the server args.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        layer_id: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.routed_scaling_factor = config.routed_scaling_factor
        self.n_shared_experts = config.n_shared_experts
        # A config without shared experts has nothing to fuse; count it as 0
        # fused experts instead of letting ``None`` reach the expert-count
        # arithmetic below (which would raise TypeError).
        self.num_fused_shared_experts = (
            0
            if (
                global_server_args_dict["disable_shared_experts_fusion"]
                or config.n_shared_experts is None
            )
            else config.n_shared_experts
        )
        self.config = config
        self.layer_id = layer_id

        if self.tp_size > config.n_routed_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.n_routed_experts}."
            )

        if config.hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {config.hidden_act}. "
                "Only silu is supported for now."
            )

        self.gate = MoEGate(config=config, prefix=add_prefix("gate", prefix))

        self.experts = get_moe_impl_class()(
            num_experts=config.n_routed_experts
            + self.num_fused_shared_experts
            + global_server_args_dict["ep_num_redundant_experts"],
            top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            layer_id=self.layer_id,
            renormalize=config.norm_topk_prob,
            quant_config=quant_config,
            use_grouped_topk=True,
            num_expert_group=config.n_group,
            num_fused_shared_experts=self.num_fused_shared_experts,
            topk_group=config.topk_group,
            correction_bias=self.gate.e_score_correction_bias,
            routed_scaling_factor=self.routed_scaling_factor,
            prefix=add_prefix("experts", prefix),
            **(
                dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
                if global_server_args_dict["enable_deepep_moe"]
                else {}
            ),
        )

        if self._has_unfused_shared_experts():
            intermediate_size = config.moe_intermediate_size * config.n_shared_experts
            # disable tp for shared experts when enable deepep moe
            self.shared_experts = DeepseekV2MLP(
                hidden_size=config.hidden_size,
                intermediate_size=intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                # The single all-reduce in forward_normal covers both the
                # routed and the shared contribution.
                reduce_results=False,
                prefix=add_prefix("shared_experts", prefix),
                **(
                    dict(tp_rank=0, tp_size=1)
                    if global_server_args_dict["enable_deepep_moe"]
                    else {}
                ),
            )

        self.top_k = config.num_experts_per_tok

        if global_server_args_dict["enable_deepep_moe"]:
            # TODO: we will support tp < ep in the future
            self.ep_size = get_tensor_model_parallel_world_size()
            self.num_experts = (
                config.n_routed_experts
                + global_server_args_dict["ep_num_redundant_experts"]
            )
            self.renormalize = config.norm_topk_prob
            self.topk_group = config.topk_group
            self.num_expert_group = config.n_group
            self.correction_bias = (
                self.gate.e_score_correction_bias.data
                if self.gate.e_score_correction_bias is not None
                else None
            )

            self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
                group=parallel_state.get_tp_group().device_group,
                router_topk=self.top_k,
                permute_fusion=True,
                num_experts=self.num_experts,
                num_local_experts=config.n_routed_experts // self.tp_size,
                hidden_size=config.hidden_size,
                params_dtype=config.torch_dtype,
                deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
                async_finish=True,
                return_recv_hook=True,
            )

        self._enable_deepep_moe = global_server_args_dict["enable_deepep_moe"]

    def get_moe_weights(self):
        """Return the raw tensors of the expert parameters, excluding ``correction_bias``."""
        return [
            x.data
            for name, x in self.experts.named_parameters()
            if name not in ["correction_bias"]
        ]

    def forward(
        self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
    ) -> torch.Tensor:
        """Run the block on the path chosen at construction time."""
        if not self._enable_deepep_moe:
            return self.forward_normal(hidden_states)
        else:
            return self.forward_deepep(hidden_states, forward_batch)

    def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Non-DeepEP path: shared + routed experts, then TP all-reduce if tp_size > 1."""
        shared_output = self._forward_shared_experts(hidden_states)
        # router_logits: (num_tokens, n_experts)
        router_logits = self.gate(hidden_states)
        final_hidden_states = self.experts(
            hidden_states=hidden_states, router_logits=router_logits
        )
        # On CUDA routed_scaling_factor is handed to self.experts in __init__
        # (presumably applied inside the kernel); elsewhere it is applied here.
        if not _is_cuda:
            final_hidden_states *= self.routed_scaling_factor
        if shared_output is not None:
            final_hidden_states = final_hidden_states + shared_output
        if self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
        return final_hidden_states

    def forward_deepep(
        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
    ) -> torch.Tensor:
        """DeepEP path: local top-k selection, dispatch, expert compute, combine.

        Idle or empty batches skip routing but still go through dispatch and
        combine with empty ``(0, top_k)`` top-k tensors, presumably so every
        rank joins the collective. The routed result is scaled by
        ``routed_scaling_factor`` and added to the shared-expert output in
        place when shared experts ran.
        """
        forward_mode = forward_batch.forward_mode
        shared_output = None
        if is_non_idle_and_non_empty(forward_mode, hidden_states):
            # router_logits: (num_tokens, n_experts)
            router_logits = self.gate(hidden_states)
            shared_output = self._forward_shared_experts(hidden_states)
            topk_weights, topk_idx = select_experts(
                hidden_states=hidden_states,
                router_logits=router_logits,
                top_k=self.top_k,
                use_grouped_topk=True,
                renormalize=self.renormalize,
                topk_group=self.topk_group,
                num_expert_group=self.num_expert_group,
                num_fused_shared_experts=self.num_fused_shared_experts,
                correction_bias=self.correction_bias,
                routed_scaling_factor=self.routed_scaling_factor,
                num_token_non_padded=forward_batch.num_token_non_padded,
                expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
                    layer_id=self.layer_id,
                ),
            )
        else:
            topk_idx = torch.full(
                (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
            )
            topk_weights = torch.empty(
                (0, self.top_k), dtype=torch.float32, device=hidden_states.device
            )
        # NOTE(review): reorder_topk_ids, seg_indptr, masked_m, expected_m and
        # num_recv_tokens_per_expert are only bound inside this branch; with
        # ep_size == 1 the self.experts call below would raise
        # UnboundLocalError. Confirm DeepEP is never enabled on a single rank.
        if self.ep_size > 1:
            # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
            (
                hidden_states,
                topk_idx,
                topk_weights,
                reorder_topk_ids,
                num_recv_tokens_per_expert,
                seg_indptr,
                masked_m,
                expected_m,
            ) = self.deepep_dispatcher.dispatch(
                hidden_states=hidden_states,
                topk_idx=topk_idx,
                topk_weights=topk_weights,
                forward_mode=forward_mode,
            )
        final_hidden_states = self.experts(
            hidden_states=hidden_states,
            topk_idx=topk_idx,
            topk_weights=topk_weights,
            reorder_topk_ids=reorder_topk_ids,
            seg_indptr=seg_indptr,
            masked_m=masked_m,
            expected_m=expected_m,
            num_recv_tokens_per_expert=num_recv_tokens_per_expert,
            forward_mode=forward_mode,
        )
        if self.ep_size > 1:
            final_hidden_states = self.deepep_dispatcher.combine(
                hidden_states=final_hidden_states,
                topk_idx=topk_idx,
                topk_weights=topk_weights,
                forward_mode=forward_mode,
            )

        if shared_output is not None:
            # Accumulate into the shared-expert buffer in place:
            # shared + routed_scaling_factor * routed.
            x = shared_output
            x.add_(final_hidden_states, alpha=self.routed_scaling_factor)
            final_hidden_states = x
        else:
            final_hidden_states *= self.routed_scaling_factor

        return final_hidden_states

    def _forward_shared_experts(self, hidden_states):
        """Return the separate shared-expert MLP output, or None if there is none."""
        if self._has_unfused_shared_experts():
            return self.shared_experts(hidden_states)
        return None

    def _has_unfused_shared_experts(self) -> bool:
        """Whether ``self.shared_experts`` exists (shared experts present and not fused)."""
        return self.n_shared_experts is not None and self.num_fused_shared_experts == 0

    # ---- Stage-wise DeepEP forward (two-batch overlap) ----------------------
    # Each op reads and writes fields on ``state``; ``state.pop`` consumes a
    # value produced by an earlier stage. op_shared_experts pops
    # ``hidden_states_mlp_input``, which op_gate, op_select_experts and
    # op_dispatch_a read, so it must run after them.

    def op_gate(self, state):
        """Compute router logits, or None for idle/empty batches."""
        if is_non_idle_and_non_empty(
            state.forward_batch.forward_mode, state.hidden_states_mlp_input
        ):
            # router_logits: (num_tokens, n_experts)
            state.router_logits = self.gate(state.hidden_states_mlp_input)
        else:
            state.router_logits = None

    def op_shared_experts(self, state):
        """Run the separate shared experts (if any) and consume the MLP input."""
        hidden_states_mlp_input = state.pop("hidden_states_mlp_input")
        if self._has_unfused_shared_experts() and is_non_idle_and_non_empty(
            state.forward_batch.forward_mode, hidden_states_mlp_input
        ):
            state.shared_output = self.shared_experts(hidden_states_mlp_input)
        else:
            state.shared_output = None

    def op_select_experts(self, state):
        """Pick local top-k experts; empty (0, top_k) tensors when there are no logits."""
        router_logits = state.pop("router_logits")
        hidden_states = state.hidden_states_mlp_input

        if router_logits is not None:
            with get_global_expert_distribution_recorder().with_current_layer(
                self.layer_id
            ):
                state.topk_weights_local, state.topk_idx_local = select_experts(
                    hidden_states=hidden_states,
                    router_logits=router_logits,
                    top_k=self.top_k,
                    use_grouped_topk=True,
                    renormalize=self.renormalize,
                    topk_group=self.topk_group,
                    num_expert_group=self.num_expert_group,
                    num_fused_shared_experts=self.num_fused_shared_experts,
                    correction_bias=self.correction_bias,
                    routed_scaling_factor=self.routed_scaling_factor,
                    num_token_non_padded=state.forward_batch.num_token_non_padded,
                    expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
                        layer_id=self.layer_id,
                    ),
                )
        else:
            state.topk_idx_local = torch.full(
                (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
            )
            state.topk_weights_local = torch.empty(
                (0, self.top_k), dtype=torch.float32, device=hidden_states.device
            )

    def op_dispatch_a(self, state):
        """Start the DeepEP dispatch (first half) when ep_size > 1."""
        if self.ep_size > 1:
            # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
            self.deepep_dispatcher.dispatch_a(
                hidden_states=state.hidden_states_mlp_input,
                topk_idx=state.pop("topk_idx_local"),
                topk_weights=state.pop("topk_weights_local"),
                forward_mode=state.forward_batch.forward_mode,
                tbo_subbatch_index=state.get("tbo_subbatch_index"),
            )

    def op_dispatch_b(self, state):
        """Finish the DeepEP dispatch and store the expert inputs on ``state``."""
        if self.ep_size > 1:
            with get_global_expert_distribution_recorder().with_current_layer(
                self.layer_id
            ):
                (
                    state.hidden_states_experts_input,
                    state.topk_idx_dispatched,
                    state.topk_weights_dispatched,
                    state.reorder_topk_ids,
                    state.num_recv_tokens_per_expert,
                    state.seg_indptr,
                    state.masked_m,
                    state.expected_m,
                ) = self.deepep_dispatcher.dispatch_b(
                    tbo_subbatch_index=state.get("tbo_subbatch_index"),
                )

    def op_experts(self, state):
        """Run the experts on the dispatched tokens."""
        state.hidden_states_experts_output = self.experts(
            hidden_states=state.pop("hidden_states_experts_input"),
            topk_idx=state.topk_idx_dispatched,
            topk_weights=state.topk_weights_dispatched,
            reorder_topk_ids=state.pop("reorder_topk_ids"),
            seg_indptr=state.pop("seg_indptr"),
            masked_m=state.pop("masked_m"),
            expected_m=state.pop("expected_m"),
            num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
            forward_mode=state.forward_batch.forward_mode,
        )

    def op_combine_a(self, state):
        """Start the DeepEP combine (first half) when ep_size > 1."""
        if self.ep_size > 1:
            self.deepep_dispatcher.combine_a(
                hidden_states=state.pop("hidden_states_experts_output"),
                topk_idx=state.pop("topk_idx_dispatched"),
                topk_weights=state.pop("topk_weights_dispatched"),
                forward_mode=state.forward_batch.forward_mode,
                tbo_subbatch_index=state.get("tbo_subbatch_index"),
            )

    def op_combine_b(self, state):
        """Finish the DeepEP combine and store the combined routed output."""
        if self.ep_size > 1:
            state.hidden_states_after_combine = self.deepep_dispatcher.combine_b(
                tbo_subbatch_index=state.get("tbo_subbatch_index"),
            )

    def op_output(self, state):
        """Scale the routed output and merge it with the shared-expert output."""
        final_hidden_states = state.pop("hidden_states_after_combine")

        if (shared_output := state.pop("shared_output")) is not None:
            x = shared_output
            x.add_(final_hidden_states, alpha=self.routed_scaling_factor)
            final_hidden_states = x
        else:
            final_hidden_states *= self.routed_scaling_factor

        state.hidden_states_mlp_output = final_hidden_states


def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    """Return the YaRN magnitude scale for a RoPE context-scaling factor.

    Computes ``0.1 * mscale * ln(scale) + 1.0`` when ``scale > 1``; for
    ``scale <= 1`` (no extension) the result is exactly ``1.0``.
    """
    import math

    if scale > 1:
        return 1.0 + 0.1 * mscale * math.log(scale)
    return 1.0


class DeepseekV2AttentionMLA(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        hidden_size: int,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: int,
        kv_lora_rank: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
        layer_id: Optional[int] = None,
        prefix: str = "",
        alt_stream: Optional[torch.cuda.Stream] = None,
    ) -> None:
        """Build the projections, rotary embedding and attention modules of an MLA layer.

        Query path: when ``q_lora_rank`` is not None, a replicated
        ``fused_qkv_a_proj_with_mqa`` (output width
        ``q_lora_rank + kv_lora_rank + qk_rope_head_dim``) followed by
        ``q_a_layernorm`` and the column-parallel ``q_b_proj``; otherwise a
        column-parallel ``q_proj`` plus a replicated ``kv_a_proj_with_mqa``.
        ``kv_b_proj``/``o_proj`` are sharded over the attention-TP group.

        Two attention modules are created for the same ``layer_id``:
        ``attn_mqa`` (1 KV head, head dim ``kv_lora_rank + qk_rope_head_dim``,
        value dim ``kv_lora_rank``) and ``attn_mha`` (``num_local_heads`` KV
        heads, head dim ``qk_nope_head_dim + qk_rope_head_dim``).

        Args:
            rope_scaling: if truthy, a copy with ``rope_type`` forced to
                ``"deepseek_yarn"`` is passed to ``get_rope`` and its
                ``factor`` / ``mscale_all_dim`` entries rescale
                ``self.scaling``; the caller's dict is left unmodified. If
                falsy, the rotary embedding uses ``forward_native``.
            layer_id: index of this layer, forwarded to both attention modules.
            alt_stream: stored on the module as ``self.alt_stream``.

        Raises:
            AssertionError: if ``num_heads`` is not divisible by the
                attention tensor-parallel size.
        """
        super().__init__()
        self.layer_id = layer_id
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        attn_tp_rank = get_attention_tp_rank()
        attn_tp_size = get_attention_tp_size()

        self.num_heads = num_heads
        assert num_heads % attn_tp_size == 0
        self.num_local_heads = num_heads // attn_tp_size
        self.scaling = self.qk_head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        # For tensor parallel attention
        if self.q_lora_rank is not None:
            self.fused_qkv_a_proj_with_mqa = ReplicatedLinear(
                self.hidden_size,
                self.q_lora_rank + self.kv_lora_rank + self.qk_rope_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("fused_qkv_a_proj_with_mqa", prefix),
            )
            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
            self.q_b_proj = ColumnParallelLinear(
                q_lora_rank,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("q_b_proj", prefix),
                tp_rank=attn_tp_rank,
                tp_size=attn_tp_size,
            )
        else:
            self.q_proj = ColumnParallelLinear(
                self.hidden_size,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("q_proj", prefix),
                tp_rank=attn_tp_rank,
                tp_size=attn_tp_size,
            )
            self.kv_a_proj_with_mqa = ReplicatedLinear(
                self.hidden_size,
                self.kv_lora_rank + self.qk_rope_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("kv_a_proj_with_mqa", prefix),
            )

        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("kv_b_proj", prefix),
            tp_rank=attn_tp_rank,
            tp_size=attn_tp_size,
        )
        # O projection.
        self.o_proj = RowParallelLinear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=add_prefix("o_proj", prefix),
            tp_rank=attn_tp_rank,
            tp_size=attn_tp_size,
        )
        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)

        if rope_scaling:
            # Work on a copy: the caller typically passes ``config.rope_scaling``,
            # which is shared by every layer and must not be mutated here.
            rope_scaling = {**rope_scaling, "rope_type": "deepseek_yarn"}

        self.rotary_emb = get_rope(
            qk_rope_head_dim,
            rotary_dim=qk_rope_head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=False,
        )

        if rope_scaling:
            mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
            scaling_factor = rope_scaling["factor"]
            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
            self.scaling = self.scaling * mscale * mscale
        else:
            self.rotary_emb.forward = self.rotary_emb.forward_native

        self.attn_mqa = RadixAttention(
            self.num_local_heads,
            self.kv_lora_rank + self.qk_rope_head_dim,
            self.scaling,
            num_kv_heads=1,
            layer_id=layer_id,
            v_head_dim=self.kv_lora_rank,
            quant_config=quant_config,
            prefix=add_prefix("attn_mqa", prefix),
        )

        self.attn_mha = RadixAttention(
            self.num_local_heads,
            self.qk_nope_head_dim + self.qk_rope_head_dim,
            self.scaling,
            num_kv_heads=self.num_local_heads,
            layer_id=layer_id,
            v_head_dim=self.v_head_dim,
            quant_config=quant_config,
            prefix=add_prefix("attn_mha", prefix),
        )

        self.alt_stream = alt_stream
        self.attn_mha.kv_b_proj = None

        # Placeholders for absorbed weights / scales; not populated here —
        # presumably filled in after weight loading (confirm in the loader).
        self.w_kc = None
        self.w_vc = None
        self.w_scale = 1.0

        self.w_scale_k = None
        self.w_scale_v = None
        self.use_deep_gemm_bmm = False

        self.flashinfer_mla_disable_ragged = global_server_args_dict[
            "flashinfer_mla_disable_ragged"
        ]
        self.disable_chunked_prefix_cache = global_server_args_dict[
            "disable_chunked_prefix_cache"
        ]
        self.attention_backend = global_server_args_dict["attention_backend"]
        self.rocm_fused_decode_mla = get_bool_env_var(
            "SGLANG_ROCM_FUSED_DECODE_MLA", "false"
        )

        # TODO: Design a finer way to determine the threshold
        self.chunked_prefix_cache_threshold = get_int_env_var(
            "SGL_CHUNKED_PREFIX_CACHE_THRESHOLD", 8192
        )

    def dispatch_attn_forward_method(
        self, forward_batch: ForwardBatch
    ) -> AttnForwardMethod:
        """Choose which attention implementation to run for ``forward_batch``.

        The decision is keyed on ``self.attention_backend``:

        * ``"flashinfer"`` -- plain MHA for extend batches with no cached
          prefix tokens (unless ``flashinfer_mla_disable_ragged`` is set),
          otherwise an MLA variant.
        * ``"fa3"`` -- MHA over a chunked prefix KV cache for extend batches
          whose total prefix length is 0 or at least
          ``self.chunked_prefix_cache_threshold`` (unless
          ``disable_chunked_prefix_cache`` is set), otherwise an MLA variant.
        * ``"aiter"`` -- MHA for every extend batch, otherwise plain MLA.
        * any other backend -- MHA for extend batches with no cached prefix
          tokens, otherwise an MLA variant.

        Target-verify and draft-extend batches never take an MHA path.

        Args:
            forward_batch: the batch about to be run.

        Returns:
            The ``AttnForwardMethod`` that ``forward_prepare`` and
            ``forward_core`` dispatch on.
        """

        def _dispatch_mla_subtype():
            # On ROCm, decode batches may use the fused MLA+RoPE path when the
            # SGLANG_ROCM_FUSED_DECODE_MLA flag is enabled.
            if _is_hip:
                if (
                    self.rocm_fused_decode_mla
                    and forward_batch.forward_mode.is_decode()
                ):
                    return AttnForwardMethod.MLA_FUSED_ROPE
                else:
                    return AttnForwardMethod.MLA
            else:
                return AttnForwardMethod.MLA

        if self.attention_backend == "flashinfer":
            # Flashinfer MLA: Do not absorb when enabling ragged prefill
            if (
                not self.flashinfer_mla_disable_ragged
                and forward_batch.forward_mode.is_extend()
                and not forward_batch.forward_mode.is_target_verify()
                and not forward_batch.forward_mode.is_draft_extend()
                and sum(forward_batch.extend_prefix_lens_cpu) == 0
            ):
                return AttnForwardMethod.MHA
            else:
                return _dispatch_mla_subtype()
        elif self.attention_backend == "fa3":
            # Flash Attention: Use MHA with chunked KV cache when prefilling on long sequences.
            # extend_prefix_lens_cpu may be None. Always bind the sum (0 = no
            # prefix) so the condition below can never hit an unbound local.
            if forward_batch.extend_prefix_lens_cpu is not None:
                sum_extend_prefix_lens = sum(forward_batch.extend_prefix_lens_cpu)
            else:
                sum_extend_prefix_lens = 0
            if (
                forward_batch.forward_mode.is_extend()
                and not self.disable_chunked_prefix_cache
                and not forward_batch.forward_mode.is_target_verify()
                and not forward_batch.forward_mode.is_draft_extend()
                and (
                    sum_extend_prefix_lens >= self.chunked_prefix_cache_threshold
                    or sum_extend_prefix_lens == 0
                )
            ):
                return AttnForwardMethod.MHA_CHUNKED_KV
            else:
                return _dispatch_mla_subtype()
        elif self.attention_backend == "aiter":
            if (
                forward_batch.forward_mode.is_extend()
                and not forward_batch.forward_mode.is_target_verify()
                and not forward_batch.forward_mode.is_draft_extend()
            ):
                return AttnForwardMethod.MHA
            else:
                return AttnForwardMethod.MLA
        else:
            # Triton: Use normal computation for prefill and use weight absorption for extend/decode
            if (
                forward_batch.forward_mode.is_extend()
                and not forward_batch.forward_mode.is_target_verify()
                and not forward_batch.forward_mode.is_draft_extend()
                and sum(forward_batch.extend_prefix_lens_cpu) == 0
            ):
                return AttnForwardMethod.MHA
            else:
                return _dispatch_mla_subtype()
Lianmin Zheng's avatar
Lianmin Zheng committed
798

799
800
801
802
803
804
805
806
807
808
809
810
811
    def op_prepare(self, state):
        """Stage wrapper around ``forward_prepare`` that works on a ``state`` object.

        Reads ``positions``, ``forward_batch`` and ``zero_allocator`` from
        ``state``. Pops ``"hidden_states_after_comm_pre_attn"`` from it. Stores
        the prepared intermediate tuple as ``state.attn_intermediate_state``.
        """
        pos = state.positions
        pre_attn_hidden = state.pop("hidden_states_after_comm_pre_attn")
        batch = state.forward_batch
        allocator = state.zero_allocator
        state.attn_intermediate_state = self.forward_prepare(
            positions=pos,
            hidden_states=pre_attn_hidden,
            forward_batch=batch,
            zero_allocator=allocator,
        )

    def op_core(self, state):
        """Stage wrapper around ``forward_core`` that works on a ``state`` object.

        Consumes ``"attn_intermediate_state"`` from ``state`` and stores the
        attention output as ``state.hidden_states_after_attn``.
        """
        prepared = state.pop("attn_intermediate_state")
        result = self.forward_core(prepared)
        state.hidden_states_after_attn = result

812
813
814
815
    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        zero_allocator: BumpAllocator,
    ):
        """Run the whole attention block in one call.

        This is ``forward_prepare`` followed immediately by ``forward_core``.
        The two halves are also exposed separately (see ``op_prepare`` /
        ``op_core``).
        """
        prepare_kwargs = dict(
            positions=positions,
            hidden_states=hidden_states,
            forward_batch=forward_batch,
            zero_allocator=zero_allocator,
        )
        return self.forward_core(self.forward_prepare(**prepare_kwargs))

    def forward_prepare(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        zero_allocator: BumpAllocator,
    ):
        """Run the pre-attention half of the layer.

        Returns a 4-tuple ``(hidden_states_or_None, attn_forward_method,
        forward_batch, inner_state)`` that ``forward_core`` consumes.

        For an empty batch it returns ``(hidden_states, None, forward_batch,
        None)``, so ``forward_core`` hands the empty tensor straight back.

        Raises:
            NotImplementedError: if the dispatched method has no prepare step.
        """
        # The MHA RadixAttention gets kv_b_proj attached lazily on first use.
        if self.attn_mha.kv_b_proj is None:
            self.attn_mha.kv_b_proj = self.kv_b_proj

        num_tokens = hidden_states.shape[0]
        if num_tokens == 0:
            assert (
                not self.o_proj.reduce_results
            ), "short-circuiting allreduce will lead to hangs"
            return hidden_states, None, forward_batch, None

        method = self.dispatch_attn_forward_method(forward_batch)
        prepare_by_method = {
            AttnForwardMethod.MHA: self.forward_normal_prepare,
            AttnForwardMethod.MHA_CHUNKED_KV: self.forward_normal_chunked_kv_prepare,
            AttnForwardMethod.MLA: self.forward_absorb_prepare,
            AttnForwardMethod.MLA_FUSED_ROPE: self.forward_absorb_fused_mla_rope_prepare,
        }
        prepare_fn = prepare_by_method.get(method)
        if prepare_fn is None:
            raise NotImplementedError
        inner_state = prepare_fn(
            positions, hidden_states, forward_batch, zero_allocator
        )
        return None, method, forward_batch, inner_state
864

865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
    def forward_core(self, intermediate_state):
        """Run the attention + output-projection half of the layer.

        ``intermediate_state`` is the 4-tuple produced by ``forward_prepare``.
        When its ``inner_state`` is None (the empty-batch short circuit), the
        carried ``hidden_states`` is returned unchanged. Otherwise the core
        step that matches the dispatched method is called with
        ``inner_state`` unpacked as its positional arguments.

        Raises:
            NotImplementedError: if the method has no core step.
        """
        passthrough, method, _, inner_state = intermediate_state
        if inner_state is None:
            return passthrough

        core_by_method = {
            AttnForwardMethod.MHA: self.forward_normal_core,
            AttnForwardMethod.MHA_CHUNKED_KV: self.forward_normal_chunked_kv_core,
            AttnForwardMethod.MLA: self.forward_absorb_core,
            AttnForwardMethod.MLA_FUSED_ROPE: self.forward_absorb_fused_mla_rope_core,
        }
        core_fn = core_by_method.get(method)
        if core_fn is None:
            raise NotImplementedError
        return core_fn(*inner_state)

    def forward_normal_prepare(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        zero_allocator: BumpAllocator,
    ):
        """Prepare q/k/v for plain (non-absorbed) multi-head attention.

        Steps:
        - Project ``hidden_states`` to per-head queries and to the latent
          cache (``kv_a`` plus the rope key part).
        - Layer-norm ``kv_a`` and expand it through ``kv_b_proj`` into
          per-head ``k_nope`` / ``v``.
        - Apply rotary embedding to the rope parts of q and k.
        - Store the latent cache (normalized ``kv_a`` plus rotated ``k_pe``)
          via ``token_to_kv_pool.set_kv_buffer``.

        ``zero_allocator`` is not used on this path.

        Returns:
            ``(q, k, v, forward_batch)`` for the MHA core step.
        """
        if self.q_lora_rank is not None:
            q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
            )
            q = self.q_a_layernorm(q)
            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
        else:
            q = self.q_proj(hidden_states)[0].view(
                -1, self.num_local_heads, self.qk_head_dim
            )
            latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]

        _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        # Add a singleton head axis so the latent cache is (tokens, 1, dim).
        latent_cache = latent_cache.unsqueeze(1)
        kv_a = self.kv_a_layernorm(kv_a.contiguous())
        # Expand the compressed latent into per-head [k_nope | v].
        kv = self.kv_b_proj(kv_a)[0]
        kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
        k_nope = kv[..., : self.qk_nope_head_dim]
        v = kv[..., self.qk_nope_head_dim :]
        k_pe = latent_cache[:, :, self.kv_lora_rank :]
        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
        # Write the rotated rope part back into q in place.
        q[..., self.qk_nope_head_dim :] = q_pe
        # Assemble per-head k = [k_nope | k_pe]; k_pe (one head) broadcasts
        # across all local heads.
        k = torch.empty_like(q)
        k[..., : self.qk_nope_head_dim] = k_nope
        k[..., self.qk_nope_head_dim :] = k_pe

        # Overwrite the latent cache in place with the normalized kv_a and the
        # rotated k_pe before storing it.
        latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
        latent_cache[:, :, self.kv_lora_rank :] = k_pe

        # Save latent cache
        forward_batch.token_to_kv_pool.set_kv_buffer(
            self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
        )

        return q, k, v, forward_batch

    def forward_normal_core(self, q, k, v, forward_batch):
        """Run plain MHA on prepared q/k/v and project to the hidden size.

        The KV cache was already written during prepare, so the attention
        call passes ``save_kv_cache=False``. The per-head outputs are
        flattened to ``num_local_heads * v_head_dim`` per token before
        ``o_proj``.
        """
        heads_out = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
        merged = heads_out.reshape(-1, self.num_local_heads * self.v_head_dim)
        projected, _ = self.o_proj(merged)
        return projected

933
    def forward_absorb_prepare(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        zero_allocator: BumpAllocator,
    ):
        """Prepare inputs for weight-absorbed MLA attention.

        Steps:
        - Project ``hidden_states`` to per-head queries and the latent cache.
        - Layer-norm the latent part (``k_nope``).
        - Multiply the non-rope query part by ``self.w_kc`` per head, which
          yields the ``kv_lora_rank``-wide ``q_nope_out``.
        - Apply rotary embedding to the rope parts of q and k.

        The ``w_kc`` multiply has several implementations: DeepGEMM FP8, a
        ROCm bf16 bmm, a per-tensor FP8 ``bmm_fp8``, or a plain ``torch.bmm``.

        Returns:
            ``(q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator)``
            for ``forward_absorb_core``.
        """
        # Function-local import; NOTE(review): presumably to avoid an import
        # cycle with the CUDA graph runner -- confirm.
        from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode

        if self.q_lora_rank is not None:
            q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
            )
            k_nope = latent_cache[..., : self.kv_lora_rank]

            # overlap qk norm
            # During CUDA graph capture with an alternate stream available,
            # kv_a_layernorm runs on alt_stream while q_a_layernorm runs on the
            # current stream. The wait_stream calls order the streams before
            # the fork and after the join.
            if self.alt_stream is not None and get_is_capture_mode():
                current_stream = torch.cuda.current_stream()
                self.alt_stream.wait_stream(current_stream)
                q = self.q_a_layernorm(q)
                with torch.cuda.stream(self.alt_stream):
                    k_nope = self.kv_a_layernorm(k_nope)
                current_stream.wait_stream(self.alt_stream)
            else:
                q = self.q_a_layernorm(q)
                k_nope = self.kv_a_layernorm(k_nope)

            k_nope = k_nope.unsqueeze(1)
            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
        else:
            q = self.q_proj(hidden_states)[0].view(
                -1, self.num_local_heads, self.qk_head_dim
            )
            latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
            k_nope = latent_cache[..., : self.kv_lora_rank]
            k_nope = self.kv_a_layernorm(k_nope).unsqueeze(1)

        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        k_pe = latent_cache[..., self.kv_lora_rank :].unsqueeze(1)

        # Per-head q_nope @ w_kc. q_nope.transpose(0, 1) puts heads first so
        # the batched matmul runs one GEMM per head.
        if self.use_deep_gemm_bmm:
            # DeepGEMM path: quantize q_nope to FP8 per token group, run a masked
            # grouped GEMM against the FP8 w_kc into a padded
            # (heads, aligned_m, kv_lora_rank) buffer, then trim the padding
            # back to expected_m.
            q_nope_val, q_nope_scale, masked_m, expected_m, aligned_m = (
                per_token_group_quant_mla_deep_gemm_masked_fp8(q_nope.transpose(0, 1))
            )
            q_nope_out = q_nope.new_empty(
                (self.num_local_heads, aligned_m, self.kv_lora_rank)
            )
            deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
                (q_nope_val, q_nope_scale),
                (self.w_kc, self.w_scale_k),
                q_nope_out,
                masked_m,
                expected_m,
            )
            q_nope_out = q_nope_out[:, :expected_m, :]
        elif _is_hip:
            # TODO(haishaw): add bmm_fp8 to ROCm
            q_nope_out = torch.bmm(
                q_nope.to(torch.bfloat16).transpose(0, 1),
                self.w_kc.to(torch.bfloat16) * self.w_scale,
            )
        elif self.w_kc.dtype == torch.float8_e4m3fn:
            # FP8 weights: quantize q_nope per tensor (the scale slot comes
            # from zero_allocator), then FP8 bmm with bf16 output.
            q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
                q_nope.transpose(0, 1),
                zero_allocator.allocate(1),
            )
            q_nope_out = bmm_fp8(
                q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
            )
        else:
            q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)

        # Back to tokens-first layout.
        q_nope_out = q_nope_out.transpose(0, 1)
        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)

        return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator

    def forward_absorb_core(
        self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
    ):
        """Run absorbed MLA attention and project the result to the hidden size.

        Backend handling:
        - ``fa3``, ``flashinfer`` and ``cutlass_mla`` receive the latent
          (``q_nope_out`` / ``k_nope``) and rope (``q_pe`` / ``k_pe``) parts
          separately.
        - Other backends receive them concatenated along the last dim.

        The attention output is viewed as (tokens, heads, kv_lora_rank). It is
        then multiplied per head by ``self.w_vc`` (DeepGEMM FP8, ROCm bf16
        bmm, per-tensor FP8 ``bmm_fp8``, or plain ``torch.bmm``), flattened
        per token, and passed through ``o_proj``.
        """
        if (
            self.attention_backend == "fa3"
            or self.attention_backend == "flashinfer"
            or self.attention_backend == "cutlass_mla"
        ):
            attn_output = self.attn_mqa(
                q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe
            )
        else:
            q = torch.cat([q_nope_out, q_pe], dim=-1)
            k = torch.cat([k_nope, k_pe], dim=-1)
            # The normalized latent (k_nope) doubles as the value.
            attn_output = self.attn_mqa(q, k, k_nope, forward_batch)
        attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)

        # Per-head attn_output @ w_vc with heads first (transpose(0, 1)).
        if self.use_deep_gemm_bmm:
            # DeepGEMM path: FP8-quantize per token group, masked grouped GEMM
            # into a padded (heads, aligned_m, v_head_dim) buffer, then trim
            # the padding back to expected_m.
            attn_output_val, attn_output_scale, masked_m, expected_m, aligned_m = (
                per_token_group_quant_mla_deep_gemm_masked_fp8(
                    attn_output.transpose(0, 1)
                )
            )
            attn_bmm_output = attn_output.new_empty(
                (self.num_local_heads, aligned_m, self.v_head_dim)
            )
            deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
                (attn_output_val, attn_output_scale),
                (self.w_vc, self.w_scale_v),
                attn_bmm_output,
                masked_m,
                expected_m,
            )
            attn_bmm_output = attn_bmm_output[:, :expected_m, :]
        elif _is_hip:
            # TODO(haishaw): add bmm_fp8 to ROCm
            attn_bmm_output = torch.bmm(
                attn_output.to(torch.bfloat16).transpose(0, 1),
                self.w_vc.to(torch.bfloat16) * self.w_scale,
            )
        elif self.w_vc.dtype == torch.float8_e4m3fn:
            # FP8 weights: per-tensor quantize (scale slot from zero_allocator),
            # then FP8 bmm with bf16 output.
            attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
                attn_output.transpose(0, 1),
                zero_allocator.allocate(1),
            )
            attn_bmm_output = bmm_fp8(
                attn_output_val,
                self.w_vc,
                attn_output_scale,
                self.w_scale,
                torch.bfloat16,
            )
        else:
            attn_bmm_output = torch.bmm(attn_output.transpose(0, 1), self.w_vc)
        # Back to tokens-first, then merge the head and feature dims for o_proj.
        attn_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
        output, _ = self.o_proj(attn_output)

        return output

1069
    def forward_absorb_fused_mla_rope_prepare(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        zero_allocator: BumpAllocator,
    ):
        """Prepare inputs for the fused MLA decode kernel.

        Only selected by the dispatcher on ROCm decode batches. It builds:
        - ``q_input``: per head, ``[q_nope @ w_kc | q_pe]``.
        - ``k_input``: the latent key ``[layer-normed kv_a | k_pe]``, written
          in place into ``latent_cache`` and stored in the KV pool.

        It then gathers everything ``forward_absorb_fused_mla_rope_core``
        passes to ``decode_attention_fwd_grouped_rope``.

        RoPE is applied here only when ``SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION``
        is not ``"1"``. Otherwise ``q_pe``/``k_pe`` are left unrotated,
        ``use_rope`` is enabled for the kernel, and an empty ``k_pe_output``
        buffer is handed to it.

        Returns:
            A 16-tuple unpacked positionally by
            ``forward_absorb_fused_mla_rope_core``.
        """
        enable_rope_fusion = (
            os.getenv("SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION", "1") == "1"
        )
        q_len = hidden_states.shape[0]
        q_input = hidden_states.new_empty(
            q_len, self.num_local_heads, self.kv_lora_rank + self.qk_rope_head_dim
        )
        if self.q_lora_rank is not None:
            q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
            )
            q = self.q_a_layernorm(q)
            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
        else:
            q = self.q_proj(hidden_states)[0].view(
                -1, self.num_local_heads, self.qk_head_dim
            )
            latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        # Per-head q_nope @ w_kc (weight absorption), heads first for the bmm.
        if _is_hip:
            # TODO(haishaw): add bmm_fp8 to ROCm
            q_nope_out = torch.bmm(
                q_nope.to(torch.bfloat16).transpose(0, 1),
                self.w_kc.to(torch.bfloat16) * self.w_scale,
            )
        elif self.w_kc.dtype == torch.float8_e4m3fn:
            q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
                q_nope.transpose(0, 1),
                zero_allocator.allocate(1),
                dtype=torch.float8_e4m3fn,
            )
            q_nope_out = bmm_fp8(
                q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
            )
        else:
            q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
        q_input[..., : self.kv_lora_rank] = q_nope_out.transpose(0, 1)
        v_input = latent_cache[..., : self.kv_lora_rank]
        v_input = self.kv_a_layernorm(v_input.contiguous()).unsqueeze(1)
        # k_input is a view of latent_cache. The normalized latent is written
        # back into it in place.
        k_input = latent_cache.unsqueeze(1)
        k_input[..., : self.kv_lora_rank] = v_input

        if not enable_rope_fusion:
            k_pe = k_input[..., self.kv_lora_rank :]
            q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
            k_input[..., self.kv_lora_rank :] = k_pe
            k_pe_output = None
        else:
            k_pe_output = torch.empty_like(k_input[..., self.kv_lora_rank :])

        # Write the (possibly rotated) rope part of the query exactly once.
        # The non-fused branch used to repeat this same copy.
        q_input[..., self.kv_lora_rank :] = q_pe

        # Output buffer filled by the fused decode kernel in the core step.
        attn_output = torch.empty(
            (q_len, self.num_local_heads, self.kv_lora_rank),
            dtype=q.dtype,
            device=q.device,
        )
        # NOTE(review): assumes the attention backend's forward_metadata is a
        # 7-tuple laid out as below -- confirm against the backend in use.
        attn_logits, _, kv_indptr, kv_indices, _, _, _ = (
            forward_batch.attn_backend.forward_metadata
        )
        cos_sin_cache = self.rotary_emb.cos_sin_cache
        num_kv_split = forward_batch.attn_backend.num_kv_splits
        sm_scale = self.attn_mqa.scaling
        if attn_logits is None:
            attn_logits = torch.empty(
                (
                    forward_batch.batch_size,
                    self.num_local_heads,
                    num_kv_split,
                    self.kv_lora_rank + 1,
                ),
                dtype=torch.float32,
                device=q.device,
            )

        # save current latent cache.
        forward_batch.token_to_kv_pool.set_kv_buffer(
            self.attn_mqa, forward_batch.out_cache_loc, k_input, None
        )
        key_cache_buf = forward_batch.token_to_kv_pool.get_key_buffer(
            self.attn_mqa.layer_id
        )
        # The value cache is the latent (first kv_lora_rank) slice of the key cache.
        val_cache_buf = key_cache_buf[..., : self.kv_lora_rank]

        return (
            q_input,
            key_cache_buf,
            val_cache_buf,
            attn_output,
            kv_indptr,
            kv_indices,
            k_pe_output,
            cos_sin_cache,
            positions,
            attn_logits,
            num_kv_split,
            sm_scale,
            enable_rope_fusion,
            k_input,
            forward_batch,
            zero_allocator,
        )

    def forward_absorb_fused_mla_rope_core(
        self,
        q_input,
        key_cache_buf,
        val_cache_buf,
        attn_output,
        kv_indptr,
        kv_indices,
        k_pe_output,
        cos_sin_cache,
        positions,
        attn_logits,
        num_kv_split,
        sm_scale,
        enable_rope_fusion,
        k_input,
        forward_batch,
        zero_allocator,
    ):
        """Run the fused MLA decode kernel and project the result to the hidden size.

        Arguments are the 16-tuple built by
        ``forward_absorb_fused_mla_rope_prepare``.

        ``decode_attention_fwd_grouped_rope`` fills ``attn_output`` in place.
        When RoPE fusion is on, the kernel is also given ``k_pe_output`` and
        ``use_rope=True``. Afterwards ``k_pe_output`` is copied into
        ``k_input`` and the latent cache entry is stored again. NOTE(review):
        this presumably assumes the kernel writes the rotated k_pe into
        ``k_pe_output`` -- confirm.

        Finally the latent output is multiplied per head by ``self.w_vc`` and
        passed through ``o_proj``.
        """
        decode_attention_fwd_grouped_rope(
            q_input,
            key_cache_buf,
            val_cache_buf,
            attn_output,
            kv_indptr,
            kv_indices,
            k_pe_output,
            self.kv_lora_rank,
            self.rotary_emb.rotary_dim,
            cos_sin_cache,
            positions,
            attn_logits,
            num_kv_split,
            sm_scale,
            logit_cap=self.attn_mqa.logit_cap,
            use_rope=enable_rope_fusion,
            is_neox_style=self.rotary_emb.is_neox_style,
        )

        if enable_rope_fusion:
            # Re-store the latent cache with the k_pe produced by the kernel.
            k_input[..., self.kv_lora_rank :] = k_pe_output
            forward_batch.token_to_kv_pool.set_kv_buffer(
                self.attn_mqa, forward_batch.out_cache_loc, k_input, None
            )

        attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)

        # Per-head attn_output @ w_vc with heads first (transpose(0, 1)).
        if _is_hip:
            # TODO(haishaw): add bmm_fp8 to ROCm
            attn_bmm_output = torch.bmm(
                attn_output.to(torch.bfloat16).transpose(0, 1),
                self.w_vc.to(torch.bfloat16) * self.w_scale,
            )
        elif self.w_vc.dtype == torch.float8_e4m3fn:
            attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
                attn_output.transpose(0, 1),
                zero_allocator.allocate(1),
                dtype=torch.float8_e4m3fn,
            )
            attn_bmm_output = bmm_fp8(
                attn_output_val,
                self.w_vc,
                attn_output_scale,
                self.w_scale,
                torch.bfloat16,
            )
        else:
            attn_bmm_output = torch.bmm(attn_output.transpose(0, 1), self.w_vc)
        # Back to tokens-first, then merge the head and feature dims for o_proj.
        attn_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
        output, _ = self.o_proj(attn_output)

        return output

1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
    def _chunked_prefix_attn_mha(
        self,
        q: torch.Tensor,
        accum_output: torch.Tensor,
        accum_lse: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Attend ``q`` over the cached prefix, one chunk at a time.

        For each prefix chunk:
        1. Gather the latent cache rows for the chunk from the KV pool.
        2. Expand ``kv_a`` through ``kv_b_proj`` into per-head ``k``/``v``.
        3. Run ``self.attn_mha``.
        4. Merge the result into the running ``(accum_output, accum_lse)``
           pair with ``merge_state_v2``.

        Args:
            q: queries for the extended tokens (rope already applied by the
                prepare step).
            accum_output: attention output accumulated so far.
            accum_lse: matching log-sum-exp values, in the transposed layout
                the caller produces.
            forward_batch: batch carrying the chunk layout
                (``num_prefix_chunks``, ``prefix_chunk_kv_indices``).

        Returns:
            The merged attention output over the extended tokens and all
            prefix chunks.
        """
        assert forward_batch.num_prefix_chunks is not None

        # get_key_buffer depends only on the layer id, not on the chunk being
        # processed, so fetch the per-layer latent buffer once instead of once
        # per chunk (NOTE(review): assumes the call is side-effect free).
        latent_cache_buf = forward_batch.token_to_kv_pool.get_key_buffer(
            self.attn_mha.layer_id
        )
        for i in range(forward_batch.num_prefix_chunks):
            forward_batch.set_prefix_chunk_idx(i)

            # Fetch latent cache from memory pool with precomputed chunked kv indices
            latent_cache = latent_cache_buf[
                forward_batch.prefix_chunk_kv_indices[i]
            ].contiguous()

            kv_a_normed, k_pe = latent_cache.split(
                [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
            )
            kv_a_normed = kv_a_normed.squeeze(1).contiguous()
            kv = self.kv_b_proj(kv_a_normed)[0]
            kv = kv.view(
                -1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim
            )
            v = kv[..., self.qk_nope_head_dim :]
            k_nope = kv[..., : self.qk_nope_head_dim]

            # Per-head k = [k_nope | k_pe]; the single-head k_pe broadcasts
            # across all local heads.
            k = torch.empty(
                (
                    k_nope.shape[0],
                    self.num_local_heads,
                    self.qk_nope_head_dim + self.qk_rope_head_dim,
                ),
                dtype=v.dtype,
                device=v.device,
            )
            k[..., : self.qk_nope_head_dim] = k_nope
            k[..., self.qk_nope_head_dim :] = k_pe

            output, lse = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
            lse = torch.transpose(lse, 0, 1).contiguous()
            tmp_output = torch.empty_like(accum_output)
            tmp_lse = torch.empty_like(accum_lse)
            merge_state_v2(output, lse, accum_output, accum_lse, tmp_output, tmp_lse)
            accum_output, accum_lse = tmp_output, tmp_lse

        return accum_output

1308
    def forward_normal_chunked_kv_prepare(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        zero_allocator: BumpAllocator,
    ):
        """Prepare q/k/v for the chunked-prefix MHA path.

        In normal MHA the k and v tensors become overly large when the prefix
        is long. To avoid this, ``forward_normal_chunked_kv_core`` first
        attends over the extended tokens only. It then processes the prefix
        part of the KV cache chunk by chunk. Since MHA is compute friendly,
        that loop does not add significant overhead. The top comments in
        https://github.com/vllm-project/vllm/blob/main/vllm/v1/attention/backends/mla/common.py
        help explain the approach.

        Preparing the extended part is exactly the plain MHA prepare step:
        projection, rope, and storing the latent cache. This method therefore
        delegates to ``forward_normal_prepare`` rather than keeping a
        line-for-line copy of it that could silently drift.

        Returns:
            ``(q, k, v, forward_batch)`` for ``forward_normal_chunked_kv_core``.
        """
        return self.forward_normal_prepare(
            positions, hidden_states, forward_batch, zero_allocator
        )

    def forward_normal_chunked_kv_core(self, q, k, v, forward_batch):
        """Attend over the extended tokens, then fold in the cached prefix.

        Pass 1 runs ``self.attn_mha`` over the new tokens only, with prefix
        attention disabled on ``forward_batch``. If any sequence in the batch
        has a cached prefix, pass 2 processes that prefix chunk by chunk via
        ``_chunked_prefix_attn_mha`` and merges it into the pass-1 result
        using the log-sum-exp values.

        The merged output is flattened to
        ``num_local_heads * v_head_dim`` per token and projected with
        ``o_proj``.
        """
        # Pass 1: extended tokens only, no prefix.
        forward_batch.set_attn_attend_prefix_cache(False)
        partial_out, partial_lse = self.attn_mha(
            q, k, v, forward_batch, save_kv_cache=False
        )
        partial_lse = partial_lse.transpose(0, 1).contiguous()

        # Pass 2: chunked prefix attention, only when some prefix exists.
        if any(forward_batch.extend_prefix_lens_cpu):
            # The chunk layout is computed lazily, only once per batch.
            if forward_batch.num_prefix_chunks is None:
                forward_batch.prepare_chunked_prefix_cache_info(q.device)
            forward_batch.set_attn_attend_prefix_cache(True)
            partial_out = self._chunked_prefix_attn_mha(
                q=q,
                accum_output=partial_out,
                accum_lse=partial_lse,
                forward_batch=forward_batch,
            )

        flat_out = partial_out.reshape(-1, self.num_local_heads * self.v_head_dim)
        projected, _ = self.o_proj(flat_out)
        return projected

1383

Liangsheng Yin's avatar
Liangsheng Yin committed
1384
1385
1386
1387
1388
1389
1390
class DeepseekV2DecoderLayer(nn.Module):
    """One DeepSeek-V2 transformer block: MLA attention followed by an MLP.

    The feed-forward part is a ``DeepseekV2MoE`` when ``_is_layer_sparse``
    says so for this layer, otherwise a dense ``DeepseekV2MLP``. Layernorms,
    residual handling and cross-rank scatter/gather are delegated to a
    ``LayerCommunicator`` built from this layer's ``LayerScatterModes``.

    Besides ``forward``, the ``op_*`` methods expose the same computation as
    separate stages that communicate through a mutable ``state`` object
    (presumably driven by the two-batch-overlap scheduler; the attention
    stages that fill ``hidden_states_after_attn`` live outside this class).
    """

    def __init__(
        self,
        config: PretrainedConfig,
        layer_id: int,
        quant_config: Optional[QuantizationConfig] = None,
        is_nextn: bool = False,
        prefix: str = "",
        alt_stream: Optional[torch.cuda.Stream] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        # Must be set before `_is_layer_sparse` is called below: it reads it.
        self.config = config
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
        self.speculative_algorithm = global_server_args_dict["speculative_algorithm"]
        self.layer_id = layer_id
        self.is_nextn = is_nextn
        self.self_attn = DeepseekV2AttentionMLA(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            qk_nope_head_dim=config.qk_nope_head_dim,
            qk_rope_head_dim=config.qk_rope_head_dim,
            v_head_dim=config.v_head_dim,
            q_lora_rank=(
                config.q_lora_rank if hasattr(config, "q_lora_rank") else None
            ),
            kv_lora_rank=config.kv_lora_rank,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            layer_id=layer_id,
            reduce_results=False,
            prefix=add_prefix("self_attn", prefix),
            alt_stream=alt_stream,
        )

        self.is_layer_sparse = self._is_layer_sparse(layer_id, is_nextn=is_nextn)
        is_previous_layer_sparse = self._is_layer_sparse(layer_id - 1, is_nextn=False)

        self.layer_scatter_modes = LayerScatterModes.init_new(
            layer_id=layer_id,
            num_layers=config.num_hidden_layers,
            is_layer_sparse=self.is_layer_sparse,
            is_previous_layer_sparse=is_previous_layer_sparse,
        )

        if self.is_layer_sparse:
            self.mlp = DeepseekV2MoE(
                config=config,
                quant_config=quant_config,
                prefix=add_prefix("mlp", prefix),
                layer_id=self.layer_id,
            )
        else:
            # With fully data-parallel dense MLPs each rank holds the whole MLP
            # (tp_size=1); otherwise let DeepseekV2MLP pick its TP defaults.
            if enable_moe_dense_fully_dp():
                mlp_tp_rank, mlp_tp_size = 0, 1
            else:
                mlp_tp_rank, mlp_tp_size = None, None
            self.mlp = DeepseekV2MLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=add_prefix("mlp", prefix),
                tp_rank=mlp_tp_rank,
                tp_size=mlp_tp_size,
            )

        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

        self.layer_communicator = LayerCommunicator(
            layer_scatter_modes=self.layer_scatter_modes,
            input_layernorm=self.input_layernorm,
            post_attention_layernorm=self.post_attention_layernorm,
        )

    def _is_layer_sparse(self, layer_id: int, is_nextn: bool) -> bool:
        """Return True if layer ``layer_id`` uses a MoE feed-forward.

        NextN layers are always sparse. Otherwise a layer is sparse when the
        config has routed experts, the layer is at or past
        ``first_k_dense_replace``, and its index is a multiple of
        ``moe_layer_freq``.
        """
        return is_nextn or (
            self.config.n_routed_experts is not None
            and layer_id >= self.config.first_k_dense_replace
            and layer_id % self.config.moe_layer_freq == 0
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
        zero_allocator: BumpAllocator,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run attention then MLP for one layer.

        Returns:
            ``(hidden_states, residual)`` as produced by
            ``layer_communicator.postprocess_layer``; the residual is fed to
            the next layer.
        """
        hidden_states, residual = self.layer_communicator.prepare_attn(
            hidden_states, residual, forward_batch
        )

        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            forward_batch=forward_batch,
            zero_allocator=zero_allocator,
        )

        hidden_states, residual = self.layer_communicator.prepare_mlp(
            hidden_states, residual, forward_batch
        )

        hidden_states = self.mlp(hidden_states, forward_batch)

        hidden_states, residual = self.layer_communicator.postprocess_layer(
            hidden_states, residual, forward_batch
        )

        if self.enable_dp_attention and self.speculative_algorithm.is_eagle():
            # NOTE: this line resolves the degradation of MTP reception rate for non-zero DP ranks.
            # See discussion here (https://github.com/sgl-project/sglang/pull/6081#discussion_r2147452251).
            hidden_states = hidden_states.clone()

        return hidden_states, residual

    def op_comm_prepare_attn(
        self,
        state,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
        zero_allocator: BumpAllocator,
        tbo_subbatch_index: Optional[int] = None,
    ):
        """Stage 1: pre-attention communication/norm; stash inputs in ``state``."""
        state.hidden_states_after_comm_pre_attn, state.residual_after_input_ln = (
            self.layer_communicator.prepare_attn(hidden_states, residual, forward_batch)
        )
        state.update(
            dict(
                forward_batch=forward_batch,
                positions=positions,
                zero_allocator=zero_allocator,
                tbo_subbatch_index=tbo_subbatch_index,
            )
        )

    def op_comm_prepare_mlp(self, state):
        """Stage: pre-MLP communication/norm on the attention output.

        Consumes ``hidden_states_after_attn`` (written by an attention stage
        outside this class) and ``residual_after_input_ln``.
        """
        state.hidden_states_mlp_input, state.residual_after_comm_pre_mlp = (
            self.layer_communicator.prepare_mlp(
                state.pop("hidden_states_after_attn"),
                state.pop("residual_after_input_ln"),
                state.forward_batch,
            )
        )

    def op_mlp(self, state):
        """Stage: run the MLP, skipping empty inputs for fully-DP dense layers."""
        hidden_states = state.pop("hidden_states_mlp_input")
        if not (
            enable_moe_dense_fully_dp()
            and (not self.is_layer_sparse)
            and hidden_states.shape[0] == 0
        ):
            # Pass the ForwardBatch itself, exactly like `forward` does; the
            # previous code passed `forward_batch.forward_mode` here, which is
            # not what the MLP's second argument expects.
            state.hidden_states_mlp_output = self.mlp(
                hidden_states, state.forward_batch
            )
        else:
            state.hidden_states_mlp_output = hidden_states

    def op_comm_postprocess_layer(self, state):
        """Final stage: post-layer communication; return next layer's inputs.

        Returns a dict of keyword arguments for the next layer's
        ``op_comm_prepare_attn`` and clears this layer's ``state`` except for
        the keys that are passed through unchanged.
        """
        hidden_states, residual = self.layer_communicator.postprocess_layer(
            state.pop("hidden_states_mlp_output"),
            state.pop("residual_after_comm_pre_mlp"),
            state.forward_batch,
        )

        output = dict(
            positions=state.positions,
            hidden_states=hidden_states,
            residual=residual,
            forward_batch=state.forward_batch,
            zero_allocator=state.zero_allocator,
            tbo_subbatch_index=state.tbo_subbatch_index,
        )

        state.clear(
            expect_keys={
                "positions",
                "forward_batch",
                "zero_allocator",
                "tbo_subbatch_index",
            }
        )
        return output
1581

Liangsheng Yin's avatar
Liangsheng Yin committed
1582
1583
1584
1585
1586
1587
1588
1589

class DeepseekV2Model(nn.Module):
    """DeepSeek-V2 backbone: token embedding, decoder layers, final RMSNorm."""

    fall_back_to_pt_during_load = False

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.padding_id = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.first_k_dense_replace = config.first_k_dense_replace

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            enable_tp=not global_server_args_dict["enable_dp_attention"],
        )
        # One auxiliary CUDA stream shared by every decoder layer.
        self.alt_stream = torch.cuda.Stream() if _is_cuda else None
        self.layers = nn.ModuleList(
            [
                DeepseekV2DecoderLayer(
                    config,
                    layer_id,
                    quant_config=quant_config,
                    prefix=add_prefix(f"layers.{layer_id}", prefix),
                    alt_stream=self.alt_stream,
                )
                for layer_id in range(config.num_hidden_layers)
            ]
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.dp_size = get_local_attention_dp_size()

    def get_input_embeddings(self) -> torch.Tensor:
        return self.embed_tokens

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
    ) -> torch.Tensor:
        """Run all decoder layers and the final norm.

        When ``forward_batch.can_run_tbo`` is set, only the first
        ``first_k_dense_replace`` layers run one-by-one here. The remaining
        layers are handed to ``model_forward_maybe_tbo``.

        Args:
            input_ids: token ids; ignored when ``input_embeds`` is given.
            positions: position ids forwarded to every layer.
            forward_batch: batch metadata.
            input_embeds: optional precomputed embeddings.

        Returns:
            The final hidden states. For idle batches they are returned
            without the final norm.
        """
        total_num_layers = len(self.layers)
        device = input_embeds.device if input_embeds is not None else input_ids.device
        # Two scratch slots per layer, doubled when two sub-batches overlap.
        zero_allocator = BumpAllocator(
            buffer_size=total_num_layers * 2 * (2 if forward_batch.can_run_tbo else 1),
            dtype=torch.float32,
            device=device,
        )

        if input_embeds is None:
            hidden_states = self.embed_tokens(input_ids)
        else:
            hidden_states = input_embeds

        residual = None

        normal_num_layers = (
            self.first_k_dense_replace
            if forward_batch.can_run_tbo
            else total_num_layers
        )
        for i in range(normal_num_layers):
            with get_global_expert_distribution_recorder().with_current_layer(i):
                layer = self.layers[i]
                hidden_states, residual = layer(
                    positions, hidden_states, forward_batch, residual, zero_allocator
                )

        if normal_num_layers != total_num_layers:
            if normal_num_layers > 0:
                # The TBO section continues where the sequential loop stopped,
                # so its input layout is the last sequential layer's output.
                input_data_scatter_mode = self.layers[
                    normal_num_layers - 1
                ].layer_scatter_modes.layer_output_mode
            else:
                # No layer ran sequentially (first_k_dense_replace == 0):
                # `self.layers[-1]` would wrongly pick the *last* layer's output
                # layout, so use what the first layer expects as input.
                input_data_scatter_mode = self.layers[
                    0
                ].layer_scatter_modes.layer_input_mode
            hidden_states, residual = model_forward_maybe_tbo(
                layers=self.layers[normal_num_layers:],
                enable_tbo=True,
                positions=positions,
                forward_batch=forward_batch,
                hidden_states=hidden_states,
                residual=residual,
                input_data_scatter_mode=input_data_scatter_mode,
                zero_allocator=zero_allocator,
            )

        if not forward_batch.forward_mode.is_idle():
            if residual is None:
                hidden_states = self.norm(hidden_states)
            else:
                hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class DeepseekV2ForCausalLM(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the backbone, LM head and logits processor.

        ``determine_num_fused_shared_experts`` runs before the backbone is
        built. It may flip ``global_server_args_dict["disable_shared_experts_fusion"]``,
        which layer construction presumably reads (TODO confirm).
        """
        super().__init__()
        self.config = config
        self.tp_size = get_tensor_model_parallel_world_size()
        self.quant_config = quant_config
        self.determine_num_fused_shared_experts()

        model_prefix = add_prefix("model", prefix)
        self.model = DeepseekV2Model(config, quant_config, prefix=model_prefix)
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=add_prefix("lm_head", prefix),
            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
        )
        self.logits_processor = LogitsProcessor(config)
        self.dp_size = get_local_attention_dp_size()

        def _collect_moe_weights():
            # layer index -> MoE weights, for the MoE layers only.
            return {
                idx: decoder_layer.mlp.get_moe_weights()
                for idx, decoder_layer in enumerate(self.model.layers)
                if isinstance(decoder_layer.mlp, DeepseekV2MoE)
            }

        self._routed_experts_weights_of_layer = LazyValue(_collect_moe_weights)

    @property
    def routed_experts_weights_of_layer(self):
        """Per-layer MoE weight mapping held by ``_routed_experts_weights_of_layer``."""
        lazy_weights = self._routed_experts_weights_of_layer
        return lazy_weights.value

1716
    def determine_num_fused_shared_experts(
        self, architecture: str = "DeepseekV3ForCausalLM"
    ):
        """Decide whether shared experts are fused into the routed experts.

        Sets ``self.num_fused_shared_experts`` to ``config.n_shared_experts``
        only when fusion is enabled globally and every requirement holds:
        CUDA with compute capability >= 9.0, the given ``architecture``,
        256 routed experts, 1 shared expert, and neither DeepEP nor EP MoE.
        Otherwise it stays 0. When a requirement fails, the global
        ``disable_shared_experts_fusion`` flag is set and the reason is logged.
        """
        self.num_fused_shared_experts = 0
        if global_server_args_dict["disable_shared_experts_fusion"]:
            return

        cfg = self.config
        # Order matters: the device query is only reached on CUDA builds.
        unsupported_model = (
            not _is_cuda
            or torch.cuda.get_device_capability("cuda") < (9, 0)
            or cfg.architectures[0] != architecture
            or cfg.n_routed_experts != 256
            or cfg.n_shared_experts != 1
        )

        # Only Deepseek V3/R1 can use shared experts fusion optimization now.
        if unsupported_model:
            disable_reason = "Only Deepseek V3/R1 on NV-platform with capability >= 90 can use shared experts fusion optimization."
        elif (
            global_server_args_dict["enable_deepep_moe"]
            or global_server_args_dict["enable_ep_moe"]
        ):
            disable_reason = "Deepseek V3/R1 can not use shared experts fusion optimization when in deepep_moe or ep_moe mode."
        else:
            self.num_fused_shared_experts = cfg.n_shared_experts
            return

        global_server_args_dict["disable_shared_experts_fusion"] = True
        log_info_on_rank0(
            logger,
            f"{disable_reason} Shared experts fusion optimization is disabled.",
        )
1748

Mick's avatar
Mick committed
1749
1750
1751
    def get_input_embeddings(self) -> nn.Embedding:
        """Return the backbone's token-embedding module."""
        backbone = self.model
        return backbone.embed_tokens

1752
    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
    ) -> torch.Tensor:
        """Run the backbone, then compute logits via ``logits_processor``/``lm_head``.

        Executed under ``torch.no_grad()``; no autograd graph is recorded.
        """
        final_hidden = self.model(input_ids, positions, forward_batch, input_embeds)
        return self.logits_processor(
            input_ids, final_hidden, self.lm_head, forward_batch
        )
Liangsheng Yin's avatar
Liangsheng Yin committed
1765

1766
    def post_load_weights(self, is_nextn=False, weight_names=None):
        """Derive the MLA absorbed weights from each layer's ``kv_b_proj``.

        For every selected layer, ``kv_b_proj`` is brought to a usable dtype
        (AWQ dequantized; fp8 requantized per-tensor, block-dequantized, or
        kept block-wise for DeepGEMM; int8 dequantized to bf16). It is then
        split into ``self_attn.w_kc`` / ``self_attn.w_vc``, plus their scales.

        Args:
            is_nextn: if True, process only the NextN draft layer stored at
                ``self.model.decoder``.
            weight_names: optional list of loaded parameter names. When given,
                only layers whose ``kv_b_proj`` appears in it are processed.
                The layer index is read from the third dotted component,
                e.g. ``model.layers.<id>.self_attn.kv_b_proj...``.
        """
        # Perform post-processing after loading weights
        if is_nextn:
            layer_ids = [self.config.num_hidden_layers]
        else:
            if weight_names is None:
                layer_ids = range(self.config.num_hidden_layers)
            else:
                layer_ids = set()
                for name in weight_names:
                    if "kv_b_proj" in name:
                        layer_id = int(name.split(".")[2])
                        if layer_id < self.config.num_hidden_layers:
                            layer_ids.add(layer_id)

        for layer_id in layer_ids:
            self_attn = (
                self.model.layers[layer_id].self_attn
                if not is_nextn
                else self.model.decoder.self_attn
            )
            if hasattr(self_attn.kv_b_proj, "qweight"):
                # AWQ compatible
                if _is_cuda:
                    w = awq_dequantize(
                        self_attn.kv_b_proj.qweight,
                        self_attn.kv_b_proj.scales,
                        self_attn.kv_b_proj.qzeros,
                    ).T
                else:
                    w = awq_dequantize(
                        self_attn.kv_b_proj.qweight,
                        self_attn.kv_b_proj.scales,
                        self_attn.kv_b_proj.qzeros,
                        0,
                        0,
                        0,
                    ).T
            else:
                w = self_attn.kv_b_proj.weight
            # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
            # This may affect the accuracy of fp8 model.
            # Fix deepseek v3 blockwise bmm by using deep_gemm
            use_deep_gemm_bmm = False
            model_dtype = torch.get_default_dtype()

            if w.dtype in (
                torch.float8_e4m3fn,
                torch.float8_e4m3fnuz,
            ):
                if (
                    hasattr(self.quant_config, "weight_block_size")
                    and self.quant_config.weight_block_size is not None
                ):
                    # Block-wise fp8.
                    weight_block_size = self.quant_config.weight_block_size
                    assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
                    if _is_fp8_fnuz:
                        weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                            weight=w,
                            weight_scale=self_attn.kv_b_proj.weight_scale_inv,
                            input_scale=None,
                        )
                    else:
                        weight = w
                        weight_scale = self_attn.kv_b_proj.weight_scale_inv

                    if (
                        _is_cuda
                        and weight_block_size[0] == 128
                        and weight_block_size[1] == 128
                        and model_dtype == torch.bfloat16
                    ):
                        if (
                            deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
                            and not deep_gemm_wrapper.DEEPGEMM_BLACKWELL
                            and get_bool_env_var("SGL_USE_DEEPGEMM_BMM", "false")
                        ):
                            # Keep fp8 + block scales; split below for DeepGEMM.
                            block_scale = weight_scale
                            use_deep_gemm_bmm = True
                        else:
                            w = block_quant_dequant(
                                weight,
                                weight_scale,
                                weight_block_size,
                                model_dtype,
                            )
                    else:
                        w, scale = block_quant_to_tensor_quant(
                            weight, weight_scale, weight_block_size
                        )
                        self_attn.w_scale = scale
                else:
                    # Channel-wise fp8: requantize to a single per-tensor scale.
                    if _is_fp8_fnuz:
                        weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                            weight=w,
                            weight_scale=self_attn.kv_b_proj.weight_scale,
                            input_scale=None,
                        )
                    else:
                        weight = w
                        weight_scale = self_attn.kv_b_proj.weight_scale

                    w, scale = channel_quant_to_tensor_quant(weight, weight_scale)
                    self_attn.w_scale = scale

            if w.dtype == torch.int8:
                # Same block/channel split as the fp8 branch above. Previously a
                # quant config exposing `weight_block_size = None` matched
                # neither path and left `w` as raw int8.
                weight_block_size = getattr(
                    self.quant_config, "weight_block_size", None
                )
                if weight_block_size is not None:
                    # block-wise int8 need it
                    assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
                    weight = w
                    weight_scale = self_attn.kv_b_proj.weight_scale_inv
                    w = int8_block_dequant(
                        weight, weight_scale, weight_block_size
                    ).to(torch.bfloat16)
                else:
                    # channel-wise int8 need it
                    w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
                        torch.bfloat16
                    )

            # Rows of `w` are grouped per head as [nope-K rows | V rows].
            w_kc, w_vc = w.unflatten(
                0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
            ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
            if not use_deep_gemm_bmm:
                self_attn.w_kc = bind_or_assign(
                    self_attn.w_kc, w_kc.transpose(1, 2).contiguous().transpose(1, 2)
                )
                self_attn.w_vc = bind_or_assign(
                    self_attn.w_vc, w_vc.contiguous().transpose(1, 2)
                )
                if (
                    hasattr(self_attn.kv_b_proj, "weight_scale")
                    and self_attn.w_scale is None
                ):
                    self_attn.w_scale = bind_or_assign(
                        self_attn.w_scale, self_attn.kv_b_proj.weight_scale
                    )
                    if _is_hip:
                        self_attn.w_scale *= 2.0
            else:
                # DeepGEMM path: split the block scales the same way as `w`.
                num_tiles_k = self_attn.qk_nope_head_dim // weight_block_size[1]
                num_tiles_n = self_attn.v_head_dim // weight_block_size[0]
                ws_kc, ws_vc = block_scale.unflatten(
                    0, (-1, (num_tiles_k + num_tiles_n))
                ).split([num_tiles_k, num_tiles_n], dim=1)
                self_attn.w_scale_k = bind_or_assign(
                    self_attn.w_scale_k, ws_kc.transpose(1, 2).contiguous()
                )
                self_attn.w_scale_v = bind_or_assign(
                    self_attn.w_scale_v, ws_vc.contiguous()
                )
                self_attn.w_kc = bind_or_assign(
                    self_attn.w_kc, w_kc.transpose(1, 2).contiguous()
                )
                self_attn.w_vc = bind_or_assign(self_attn.w_vc, w_vc.contiguous())
                self_attn.use_deep_gemm_bmm = True

        if (
            deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
            and deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
            and hasattr(self.quant_config, "weight_block_size")
            and self.quant_config.weight_block_size is not None
        ):
            self._weight_requant_ue8m0()

    def _weight_requant_ue8m0(self):
        """Requantize block-fp8 weights in place via ``requant_weight_ue8m0_inplace``.

        Covers each layer's attention projections, and either its MoE
        (shared experts plus DeepEP expert weights) or its dense MLP. It is
        a no-op for the ``DeepseekV3ForCausalLMNextN`` architecture.
        """
        if self.config.architectures[0] == "DeepseekV3ForCausalLMNextN":
            return

        weight_block_size = self.quant_config.weight_block_size

        for layer_id in range(self.config.num_hidden_layers):
            layer = self.model.layers[layer_id]

            for module in [
                layer.self_attn.fused_qkv_a_proj_with_mqa,
                layer.self_attn.q_b_proj,
                layer.self_attn.kv_b_proj,
                layer.self_attn.o_proj,
            ]:
                requant_weight_ue8m0_inplace(
                    module.weight, module.weight_scale_inv, weight_block_size
                )

            # Dispatch on the module actually built for this layer. Previously
            # MoE layers were recomputed as
            # range(first_k_dense_replace, num_hidden_layers, moe_layer_freq),
            # which disagrees with the decoder's `layer_id % moe_layer_freq == 0`
            # rule and ignores `n_routed_experts is None`.
            if isinstance(layer.mlp, DeepseekV2MoE):
                shared_experts = getattr(layer.mlp, "shared_experts", None)
                if shared_experts is not None:
                    for module in [
                        shared_experts.gate_up_proj,
                        shared_experts.down_proj,
                    ]:
                        requant_weight_ue8m0_inplace(
                            module.weight, module.weight_scale_inv, weight_block_size
                        )

                experts = layer.mlp.experts
                if isinstance(experts, DeepEPMoE):
                    # Each entry is a (weight, scale) pair.
                    for w in [
                        experts.w13_weight_fp8,
                        experts.w2_weight_fp8,
                    ]:
                        requant_weight_ue8m0_inplace(w[0], w[1], weight_block_size)
            else:
                mlp = layer.mlp
                assert isinstance(mlp, DeepseekV2MLP)
                for module in [
                    mlp.gate_up_proj,
                    mlp.down_proj,
                ]:
                    requant_weight_ue8m0_inplace(
                        module.weight, module.weight_scale_inv, weight_block_size
                    )

1989
    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
1990

1991
1992
1993
        if is_nextn:
            if hasattr(self.config, "num_nextn_predict_layers"):
                num_nextn_layers = self.config.num_nextn_predict_layers
1994
                assert num_nextn_layers == 1, "Only 1 nextn layer is supported"
1995
1996
1997
1998
1999
2000
2001
2002
2003
                # compatible with old design
                nextn_layer_id = (
                    0
                    if self.config.num_hidden_layers == 1
                    else self.config.num_hidden_layers
                )
            else:
                raise ValueError("num_nextn_predict_layers is not in the config")

Liangsheng Yin's avatar
Liangsheng Yin committed
2004
2005
2006
2007
2008
2009
2010
2011
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
2012
        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
Liangsheng Yin's avatar
Liangsheng Yin committed
2013
2014
2015
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
2016
            num_experts=self.config.n_routed_experts + self.num_fused_shared_experts,
Liangsheng Yin's avatar
Liangsheng Yin committed
2017
2018
        )

2019
2020
2021
2022
2023
2024
        # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
        fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (
            self.config.q_lora_rank is not None
        )
        cached_a_proj = {} if fuse_qkv_a_proj else None

2025
2026
2027
2028
2029
2030
2031
2032
2033
        if is_nextn:
            nextn_layer_prefix = f"model.layers.{nextn_layer_id}"
            nextn_spec_weight_names = [
                "shared_head.norm",
                "eh_proj",
                "enorm",
                "hnorm",
            ]

2034
2035
2036
2037
        if self.num_fused_shared_experts > 0:
            assert self.num_fused_shared_experts == 1
            logger.info("Shared experts fusion optimization enabled.")

Liangsheng Yin's avatar
Liangsheng Yin committed
2038
        params_dict = dict(self.named_parameters())
2039
        weight_names = []
Liangsheng Yin's avatar
Liangsheng Yin committed
2040
        for name, loaded_weight in weights:
2041
2042
2043
2044
2045
2046
            if self.num_fused_shared_experts > 0 and "mlp.shared_experts" in name:
                name = name.replace(
                    "mlp.shared_experts",
                    f"mlp.experts.{self.config.n_routed_experts}",
                )

2047
2048
            weight_names.append(name)

2049
2050
2051
2052
2053
2054
2055
2056
2057
2058
2059
2060
2061
2062
2063
2064
2065
2066
2067
2068
2069
2070
2071
2072
2073
2074
2075
2076
2077
            if not is_nextn:
                if hasattr(self.config, "num_nextn_predict_layers"):
                    num_nextn_layers = self.config.num_nextn_predict_layers
                    if num_nextn_layers > 0 and name.startswith("model.layers"):
                        name_list = name.split(".")
                        if (
                            len(name_list) >= 3
                            and int(name_list[2]) >= self.config.num_hidden_layers
                        ):
                            continue
            else:
                if not name.startswith(nextn_layer_prefix):
                    continue

                # Use shared head and embed weights from target model
                if "shared_head.head" in name or "embed_tokens" in name:
                    continue

                is_decoder = True
                # For nextn specific weights
                for weight_name in nextn_spec_weight_names:
                    if weight_name in name:
                        name = name.replace(nextn_layer_prefix, "model")
                        is_decoder = False
                        break
                # For decoder layer weights
                if is_decoder:
                    name = name.replace(nextn_layer_prefix, "model.decoder")

Liangsheng Yin's avatar
Liangsheng Yin committed
2078
2079
2080
2081
2082
2083
2084
2085
2086
2087
2088
2089
2090
2091
2092
2093
2094
2095
2096
2097
2098
2099
2100
2101
2102
2103
2104
2105
2106
2107
2108
2109
2110
            if "rotary_emb.inv_freq" in name:
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if ("mlp.experts." in name) and name not in params_dict:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(
                        param,
                        loaded_weight,
2111
                        name,
Liangsheng Yin's avatar
Liangsheng Yin committed
2112
2113
2114
2115
2116
2117
2118
2119
                        shard_id=shard_id,
                        expert_id=expert_id,
                    )
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
2120
2121
2122
2123
2124
2125
2126
2127
2128
2129
2130
2131
2132
2133
2134
2135
2136
2137
2138
2139
2140
2141
                    if fuse_qkv_a_proj and (
                        "q_a_proj" in name or "kv_a_proj_with_mqa" in name
                    ):
                        cached_a_proj[name] = loaded_weight
                        q_a_proj_name = (
                            name
                            if "q_a_proj" in name
                            else name.replace("kv_a_proj_with_mqa", "q_a_proj")
                        )
                        kv_a_proj_name = (
                            name
                            if "kv_a_proj_with_mqa" in name
                            else name.replace("q_a_proj", "kv_a_proj_with_mqa")
                        )

                        # When both q_a_proj and kv_a_proj_with_mqa has been cached, load the fused weight to parameter
                        if (
                            q_a_proj_name in cached_a_proj
                            and kv_a_proj_name in cached_a_proj
                        ):
                            q_a_proj_weight = cached_a_proj[q_a_proj_name]
                            kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
2142
2143
2144
2145
2146
2147
                            cat_dim = 0
                            if (
                                self.quant_config.get_name() == "awq"
                                or self.quant_config.get_name() == "moe_wna16"
                            ):
                                cat_dim = 1
2148
                            fused_weight = torch.cat(
2149
                                [q_a_proj_weight, kv_a_proj_weight], dim=cat_dim
2150
                            )
2151
2152
2153
2154
2155
2156
                            param_name = (
                                name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")
                                if "q_a_proj" in name
                                else name.replace(
                                    "kv_a_proj_with_mqa", "fused_qkv_a_proj_with_mqa"
                                )
2157
2158
2159
2160
2161
2162
2163
2164
2165
2166
                            )
                            param = params_dict[param_name]

                            weight_loader = getattr(
                                param, "weight_loader", default_weight_loader
                            )
                            weight_loader(param, fused_weight)
                            cached_a_proj.pop(q_a_proj_name)
                            cached_a_proj.pop(kv_a_proj_name)
                    else:
2167
2168
2169
2170
                        if (
                            "k_scale" in name or "v_scale" in name
                        ) and name not in params_dict:
                            # modelopt attn kv scale is named differently
2171
2172
2173
                            for scale in ["k_scale", "v_scale"]:
                                if scale in name:
                                    name = name.replace(f"{scale[0]}_proj", "attn_mqa")
2174
2175
2176
2177
2178
                        param = params_dict[name]
                        weight_loader = getattr(
                            param, "weight_loader", default_weight_loader
                        )
                        weight_loader(param, loaded_weight)
Liangsheng Yin's avatar
Liangsheng Yin committed
2179

2180
        self.post_load_weights(is_nextn=is_nextn, weight_names=weight_names)
Ke Bao's avatar
Ke Bao committed
2181

2182
2183
2184
2185
2186
2187
2188
2189
2190
2191
2192
    def get_embed_and_head(self):
        """Return the input-embedding weight and the LM-head weight.

        Returns:
            A 2-tuple ``(embed_weight, head_weight)``, taken from
            ``self.model.embed_tokens.weight`` and ``self.lm_head.weight``.
        """
        embed_weight = self.model.embed_tokens.weight
        head_weight = self.lm_head.weight
        return embed_weight, head_weight

    def set_embed_and_head(self, embed, head):
        """Replace the input-embedding weight and the LM-head weight in place.

        Args:
            embed: New value assigned to ``self.model.embed_tokens.weight``.
            head: New value assigned to ``self.lm_head.weight``.

        NOTE(review): presumably used to share the embedding/head tensors of
        another (target) model, e.g. for a nextn/draft model whose loader skips
        ``embed_tokens`` and ``shared_head.head`` weights -- confirm caller.
        """
        # Drop this module's references to the current weights first. Both
        # deletions intentionally happen before either assignment; keep this
        # order (if the two modules ever aliased each other, the second `del`
        # would raise instead of silently overwriting).
        del self.model.embed_tokens.weight
        del self.lm_head.weight
        self.model.embed_tokens.weight = embed
        self.lm_head.weight = head
        # Return unoccupied cached blocks from PyTorch's CUDA caching allocator
        # (e.g. memory no longer referenced after the deletions above), then
        # wait for all outstanding CUDA work to finish.
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

2193
2194
2195
2196
2197
2198
2199
2200
    @classmethod
    def get_model_config_for_expert_location(cls, config):
        """Build the expert-location model description from a HF config.

        Args:
            config: Model config providing ``num_hidden_layers``,
                ``n_routed_experts`` and ``n_group``.

        Returns:
            A ``ModelConfigForExpertLocation`` whose logical-expert count is the
            number of routed experts in ``config``.
        """
        layer_count = config.num_hidden_layers
        routed_expert_count = config.n_routed_experts
        group_count = config.n_group
        return ModelConfigForExpertLocation(
            num_layers=layer_count,
            num_logical_experts=routed_expert_count,
            num_groups=group_count,
        )

Liangsheng Yin's avatar
Liangsheng Yin committed
2201

HandH1998's avatar
HandH1998 committed
2202
2203
2204
2205
2206
class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
    """DeepSeek-V3 causal LM; reuses the DeepseekV2ForCausalLM implementation unchanged."""


# Model classes exported by this module.
EntryClass = [
    DeepseekV2ForCausalLM,
    DeepseekV3ForCausalLM,
]