deepseek_v2.py 38.5 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

wangding zeng's avatar
wangding zeng committed
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
24
"""Inference-only DeepseekV2/DeepseekV3 model."""
王敏's avatar
王敏 committed
25
26
import os
import re
27
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
28
import vllm.envs as envs
wangding zeng's avatar
wangding zeng committed
29
30
31
32
33
import torch
from torch import nn
from transformers import PretrainedConfig

from vllm.attention import Attention, AttentionMetadata
34
from vllm.compilation.decorators import support_torch_compile
35
from vllm.config import CacheConfig, ModelConfig, VllmConfig
36
37
from vllm.distributed import (get_pp_group,
                              get_tensor_model_parallel_world_size,
wangding zeng's avatar
wangding zeng committed
38
39
                              tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.activation import SiluAndMul
40
from vllm.model_executor.layers.fused_moe import FusedMoE
wangding zeng's avatar
wangding zeng committed
41
42
43
44
45
46
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               MergedColumnParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
47
from vllm.model_executor.layers.quantization import QuantizationConfig
wangding zeng's avatar
wangding zeng committed
48
from vllm.model_executor.layers.rotary_embedding import get_rope
Joe Runde's avatar
Joe Runde committed
49
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
wangding zeng's avatar
wangding zeng committed
50
51
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
52
53
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, maybe_remap_kv_scale_name)
wangding zeng's avatar
wangding zeng committed
54
from vllm.model_executor.sampling_metadata import SamplingMetadata
55
from vllm.sequence import IntermediateTensors
wangding zeng's avatar
wangding zeng committed
56

57
58
from .interfaces import SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter,
59
60
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
王敏's avatar
王敏 committed
61
from vllm import _custom_ops as ops
62

wangding zeng's avatar
wangding zeng committed
63
64
65
66
67
68
69
70
71
72

class DeepseekV2MLP(nn.Module):
    """Dense gated feed-forward block used for dense layers and shared experts.

    The gate and up projections are fused into a single column-parallel
    matmul (``gate_up_proj`` produces two halves of width
    ``intermediate_size``), ``SiluAndMul`` combines them, and the row-parallel
    ``down_proj`` maps back to ``hidden_size``.

    Args:
        hidden_size: Model hidden dimension (input and output width).
        intermediate_size: Width of each of the gate / up projections.
        hidden_act: Activation name; only ``"silu"`` is accepted.
        quant_config: Optional quantization config for both projections.
        reduce_results: Forwarded to ``down_proj``. Pass ``False`` when the
            caller performs the tensor-parallel all-reduce itself.
        prefix: Module name prefix for the sub-layers.

    Raises:
        ValueError: If ``hidden_act`` is not ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # One matmul emitting [gate | up], each intermediate_size wide.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj")
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           quant_config=quant_config,
                                           reduce_results=reduce_results,
                                           prefix=f"{prefix}.down_proj")
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply gate/up projection, SiluAndMul, then the down projection."""
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class DeepseekV2MoE(nn.Module):
    """Mixture-of-experts FFN with optional always-on shared experts.

    ``gate`` scores every token against ``config.n_routed_experts`` experts.
    ``FusedMoE`` runs the selected ``config.num_experts_per_tok`` experts, and
    its output is multiplied by ``config.routed_scaling_factor``. When
    ``config.n_shared_experts`` is set, a dense ``DeepseekV2MLP`` also runs on
    every token and its output is added to the routed output.

    Both sub-modules are built with ``reduce_results=False``. The
    tensor-parallel all-reduce is therefore done once, at the end of
    ``forward``.

    Args:
        config: HF model config carrying the DeepseekV2/V3 MoE fields.
        quant_config: Quantization config for experts / shared experts. The
            router gate is always left unquantized.
        prefix: Module name prefix for the sub-layers.
        moe_ep_size: Expert-parallel size passed through to ``FusedMoE``.

    Raises:
        ValueError: If the TP size exceeds the number of routed experts, or
            the activation is not ``"silu"``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        moe_ep_size: int = 1
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.routed_scaling_factor = config.routed_scaling_factor
        self.n_shared_experts = config.n_shared_experts
        if self.tp_size > config.n_routed_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.n_routed_experts}.")

        if config.hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {config.hidden_act}. "
                             "Only silu is supported for now.")

        # Router: replicated on every rank and deliberately not quantized.
        self.gate = ReplicatedLinear(config.hidden_size,
                                     config.n_routed_experts,
                                     bias=False,
                                     quant_config=None,
                                     prefix=f"{prefix}.gate")
        if config.topk_method == "noaux_tc":
            # Per-expert routing bias used by the "noaux_tc" top-k method.
            # It is allocated uninitialized here. Presumably weight loading
            # fills it in; confirm in load_weights.
            self.gate.e_score_correction_bias = nn.Parameter(
                torch.empty(config.n_routed_experts))
        else:
            self.gate.e_score_correction_bias = None

        self.experts = FusedMoE(
            num_experts=config.n_routed_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            reduce_results=False,
            renormalize=config.norm_topk_prob,
            quant_config=quant_config,
            use_grouped_topk=True,
            num_expert_group=config.n_group,
            topk_group=config.topk_group,
            prefix=f"{prefix}.experts",
            scoring_func=config.scoring_func,
            e_score_correction_bias=self.gate.e_score_correction_bias,
            moe_ep_size=moe_ep_size)

        if config.n_shared_experts is not None:
            # All shared experts are fused into one wider dense MLP.
            intermediate_size = (config.moe_intermediate_size *
                                 config.n_shared_experts)
            self.shared_experts = DeepseekV2MLP(
                hidden_size=config.hidden_size,
                intermediate_size=intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                reduce_results=False,
                # Give the sub-layers their real names. Without this they were
                # ".gate_up_proj" / ".down_proj", which prefix-based
                # quantization lookups cannot match.
                prefix=f"{prefix}.shared_experts",
            )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens through the experts and add the shared-expert output.

        Args:
            hidden_states: 2-D ``(num_tokens, hidden_dim)`` activations.

        Returns:
            Tensor with the same shape, all-reduced across TP ranks when
            ``tp_size > 1``.
        """
        num_tokens, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        # Bind unconditionally. Previously this name was only assigned inside
        # the branch below, so models without shared experts hit an
        # UnboundLocalError at the `is not None` check.
        shared_output = None
        if self.n_shared_experts is not None:
            shared_output = self.shared_experts(hidden_states)
        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        final_hidden_states = self.experts(
            hidden_states=hidden_states,
            router_logits=router_logits) * self.routed_scaling_factor
        if shared_output is not None:
            final_hidden_states = final_hidden_states + shared_output
        # Both branches were built with reduce_results=False, so reduce once.
        if self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(
                final_hidden_states)

        return final_hidden_states.view(num_tokens, hidden_dim)


def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    """Return the YaRN attention-magnitude factor for a RoPE scale.

    The factor is exactly ``1.0`` when ``scale <= 1``, i.e. when the context
    is not being extended. Otherwise it grows logarithmically:
    ``1 + 0.1 * mscale * ln(scale)``.
    """
    from math import log
    return 1.0 if scale <= 1 else 1.0 + 0.1 * mscale * log(scale)


class DeepseekV2Attention(nn.Module):
    """DeepseekV2 latent attention computed in the expanded (non-MLA) form.

    Keys and values come from a low-rank latent (``kv_lora_rank``). The
    latent is RMS-normalized and expanded per head by ``kv_b_proj`` into
    ``[k_nope | v]``. Each query/key head is the concatenation of a
    non-rotary part (``qk_nope_head_dim``) and a rotary part
    (``qk_rope_head_dim``). The rotary key part comes from
    ``kv_a_proj_with_mqa`` without a head axis and is broadcast to every
    head.

    V is zero-padded to ``qk_head_dim`` so q, k and v share one head size for
    the generic ``Attention`` layer. The padding is sliced off again before
    ``o_proj``.

    Note:
        When ``rope_scaling`` is given, the dict is mutated in place:
        ``rope_type`` is forced to ``'deepseek_yarn'``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        hidden_size: int,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: Optional[int],
        kv_lora_rank: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.num_heads = num_heads
        tp_size = get_tensor_model_parallel_world_size()
        assert num_heads % tp_size == 0
        self.num_local_heads = num_heads // tp_size
        self.scaling = self.qk_head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        if self.q_lora_rank is not None:
            # Low-rank query path: hidden -> q_lora_rank -> RMSNorm -> heads.
            self.q_a_proj = ReplicatedLinear(self.hidden_size,
                                             self.q_lora_rank,
                                             bias=False,
                                             quant_config=quant_config,
                                             prefix=f"{prefix}.q_a_proj")
            self.q_a_layernorm = RMSNorm(self.q_lora_rank,
                                         eps=config.rms_norm_eps)
            self.q_b_proj = ColumnParallelLinear(q_lora_rank,
                                                 self.num_heads *
                                                 self.qk_head_dim,
                                                 bias=False,
                                                 quant_config=quant_config,
                                                 prefix=f"{prefix}.q_b_proj")
        else:
            self.q_proj = ColumnParallelLinear(self.hidden_size,
                                               self.num_heads *
                                               self.qk_head_dim,
                                               bias=False,
                                               quant_config=quant_config,
                                               prefix=f"{prefix}.q_proj")

        # Projects to [kv latent (kv_lora_rank) | rotary key (qk_rope_head_dim)].
        self.kv_a_proj_with_mqa = ReplicatedLinear(
            self.hidden_size,
            self.kv_lora_rank + self.qk_rope_head_dim,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_a_proj_with_mqa")
        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
                                      eps=config.rms_norm_eps)
        # Expands the normalized latent into per-head [k_nope | v].
        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_b_proj")
        # O projection.
        self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
                                        self.hidden_size,
                                        bias=False,
                                        quant_config=quant_config,
                                        prefix=f"{prefix}.o_proj")

        if rope_scaling:
            # NOTE: mutates the caller's dict (typically the HF config's
            # rope_scaling) so get_rope builds the DeepSeek YaRN variant.
            rope_scaling["rope_type"] = 'deepseek_yarn'
            self.use_normal_rope = False
        else:
            self.use_normal_rope = True
        self.rotary_emb = get_rope(qk_rope_head_dim,
                                   rotary_dim=qk_rope_head_dim,
                                   max_position=max_position_embeddings,
                                   base=rope_theta,
                                   rope_scaling=rope_scaling,
                                   is_neox_style=False)

        if rope_scaling:
            # YaRN magnitude correction, applied to the softmax scale (squared
            # because it scales both q and k).
            mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
            scaling_factor = rope_scaling["factor"]
            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
            self.scaling = self.scaling * mscale * mscale

        self.attn = Attention(self.num_local_heads,
                              self.qk_head_dim,
                              self.scaling,
                              num_kv_heads=self.num_local_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Compute attention for a flat batch of tokens.

        Returns the ``o_proj`` output, of width ``hidden_size``.
        """
        # q: (num_tokens, num_local_heads, qk_head_dim)
        if self.q_lora_rank is not None:
            q = self.q_a_proj(hidden_states)[0]
            q = self.q_a_layernorm(q)
            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads,
                                         self.qk_head_dim)
        else:
            q = self.q_proj(hidden_states)[0].view(-1, self.num_local_heads,
                                                   self.qk_head_dim)
        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim],
                               dim=-1)
        # latent_cache: [kv latent | rotary key]. The singleton head axis
        # added by unsqueeze(1) lets k_pe broadcast over all heads below.
        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
        kv_a, _ = latent_cache.split(
            [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        latent_cache = latent_cache.unsqueeze(1)
        kv_a = self.kv_a_layernorm(kv_a.contiguous())
        kv = self.kv_b_proj(kv_a)[0]
        kv = kv.view(-1, self.num_local_heads,
                     self.qk_nope_head_dim + self.v_head_dim)
        k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
        k_pe = latent_cache[:, :, self.kv_lora_rank:]

        if self.use_normal_rope:
            # The plain rotary embedding is fed 2-D (num_tokens, -1) tensors
            # here; the per-head shapes are restored after the call.
            seq_len = positions.size(0)
            ori_q_pe_shape, ori_k_pe_shape = q_pe.shape, k_pe.shape
            q_pe = q_pe.reshape(seq_len, -1)
            k_pe = k_pe.reshape(seq_len, -1)

        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)

        if self.use_normal_rope:
            q_pe, k_pe = q_pe.view(ori_q_pe_shape), k_pe.view(ori_k_pe_shape)

        # Write the rotated part back into q and assemble full per-head keys.
        q[..., self.qk_nope_head_dim:] = q_pe
        k = torch.empty_like(q)
        k[..., :self.qk_nope_head_dim] = k_nope
        k[..., self.qk_nope_head_dim:] = k_pe
        # padding value to qk_head_dim for alignment
        v = torch.nn.functional.pad(
            v, [0, self.qk_head_dim - self.v_head_dim],
            value=0).view(-1, self.num_local_heads * self.qk_head_dim)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        # Drop the zero-padded tail of each head before the output projection.
        attn_output = attn_output.view(
            -1, self.num_local_heads,
            self.qk_head_dim)[..., :self.v_head_dim].reshape(
                -1, self.num_local_heads * self.v_head_dim)
        output, _ = self.o_proj(attn_output)
        return output


345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
class DeepseekV2MLAAttention(nn.Module):
    """
    Main reference: DeepseekV2 paper, and FlashInfer Implementation
    (https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).

    For more info see MLACommonImpl in: vllm/attention/backends/mla/utils.py

    This module only produces the attention backend's inputs:
      * the query input: either the RMS-normalized ``q_a_proj`` output, or
        the raw hidden states when ``q_lora_rank`` is None;
      * the RMS-normalized KV latent;
      * the rotary key part.

    The query/KV up-projections (``q_proj``/``q_b_proj``, ``kv_b_proj``),
    ``o_proj`` and ``rotary_emb`` are handed to ``Attention(use_mla=True)``.
    That layer is responsible for applying them; this ``forward`` never calls
    them. The attention layer is configured with ``num_kv_heads=1`` and
    ``head_size=kv_lora_rank``.

    Note:
        When ``rope_scaling`` is given, the dict is mutated in place:
        ``rope_type`` is forced to ``'deepseek_yarn'``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        hidden_size: int,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: Optional[int],
        kv_lora_rank: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim

        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank

        self.num_heads = num_heads
        tp_size = get_tensor_model_parallel_world_size()
        assert num_heads % tp_size == 0
        self.num_local_heads = num_heads // tp_size

        self.scaling = self.qk_head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        if self.q_lora_rank is not None:
            self.q_a_proj = ReplicatedLinear(self.hidden_size,
                                             self.q_lora_rank,
                                             bias=False,
                                             quant_config=quant_config,
                                             prefix=f"{prefix}.q_a_proj")
            self.q_a_layernorm = RMSNorm(self.q_lora_rank,
                                         eps=config.rms_norm_eps)
            self.q_b_proj = ColumnParallelLinear(q_lora_rank,
                                                 self.num_heads *
                                                 self.qk_head_dim,
                                                 bias=False,
                                                 quant_config=quant_config,
                                                 prefix=f"{prefix}.q_b_proj")
        else:
            self.q_proj = ColumnParallelLinear(self.hidden_size,
                                               self.num_heads *
                                               self.qk_head_dim,
                                               bias=False,
                                               quant_config=quant_config,
                                               prefix=f"{prefix}.q_proj")

        # Projects to [kv latent (kv_lora_rank) | rotary key (qk_rope_head_dim)].
        self.kv_a_proj_with_mqa = ReplicatedLinear(
            self.hidden_size,
            self.kv_lora_rank + self.qk_rope_head_dim,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_a_proj_with_mqa")
        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
                                      eps=config.rms_norm_eps)
        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_b_proj")
        self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
                                        self.hidden_size,
                                        bias=False,
                                        quant_config=quant_config,
                                        prefix=f"{prefix}.o_proj")

        if rope_scaling:
            # NOTE: mutates the caller's dict so get_rope builds the DeepSeek
            # YaRN variant.
            rope_scaling["rope_type"] = 'deepseek_yarn'
        self.rotary_emb = get_rope(qk_rope_head_dim,
                                   rotary_dim=qk_rope_head_dim,
                                   max_position=max_position_embeddings,
                                   base=rope_theta,
                                   rope_scaling=rope_scaling,
                                   is_neox_style=False)
        if rope_scaling:
            # YaRN magnitude correction, applied to the softmax scale.
            mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
            scaling_factor = rope_scaling["factor"]
            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
            self.scaling = self.scaling * mscale * mscale

        self.mla_attn = Attention(
            num_heads=self.num_local_heads,
            head_size=self.kv_lora_rank,
            scale=self.scaling,
            num_kv_heads=1,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            use_mla=True,
            # MLA Args
            q_lora_rank=self.q_lora_rank,
            kv_lora_rank=self.kv_lora_rank,
            qk_nope_head_dim=self.qk_nope_head_dim,
            qk_rope_head_dim=self.qk_rope_head_dim,
            qk_head_dim=self.qk_head_dim,
            v_head_dim=self.v_head_dim,
            rotary_emb=self.rotary_emb,
            q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj,
            kv_b_proj=self.kv_b_proj,
            o_proj=self.o_proj,
        )

        self.prefix = prefix
        # Assumes prefix ends in "<layer_idx>.<name>" (e.g.
        # "model.layers.3.self_attn"). Any other shape raises here.
        self.debug_layer_idx = int(self.prefix.split(".")[-2])

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Compute the compressed MLA inputs and delegate to ``mla_attn``."""
        if self.q_lora_rank is not None:
            ckq = self.q_a_proj(hidden_states)[0]
            hidden_states_or_q_c = self.q_a_layernorm(ckq)
        else:
            hidden_states_or_q_c = hidden_states
        kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
            [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
        return self.mla_attn(hidden_states_or_q_c, kv_c_normed, k_pe, kv_cache,
                             attn_metadata)


wangding zeng's avatar
wangding zeng committed
489
490
491
492
493
class DeepseekV2DecoderLayer(nn.Module):
    """One transformer block: pre-norm attention followed by a pre-norm FFN.

    Attention is ``DeepseekV2MLAAttention`` when ``model_config.use_mla`` is
    set, and ``DeepseekV2Attention`` otherwise. The FFN is a
    ``DeepseekV2MoE`` when all of the following hold:
      * the config has routed experts;
      * the layer index is ``>= first_k_dense_replace``;
      * the layer index is divisible by ``moe_layer_freq``.
    Otherwise it is a dense ``DeepseekV2MLP``.

    Args:
        config: HF model config.
        prefix: Module prefix whose last dot-separated component is the
            layer index (as produced by ``make_layers``).
        model_config: vLLM model config (used for ``use_mla``).
        cache_config: KV-cache config forwarded to attention.
        quant_config: Quantization config forwarded to sub-layers.
        moe_ep_size: Expert-parallel size forwarded to the MoE block.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        prefix: str,
        model_config: ModelConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        moe_ep_size: int = 1,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        # DecoderLayers are created with `make_layers` which passes the prefix
        # with the layer's index.
        layer_idx = int(prefix.split(sep='.')[-1])
        if model_config.use_mla:
            attn_cls = DeepseekV2MLAAttention
        else:
            attn_cls = DeepseekV2Attention
        self.self_attn = attn_cls(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            qk_nope_head_dim=config.qk_nope_head_dim,
            qk_rope_head_dim=config.qk_rope_head_dim,
            v_head_dim=config.v_head_dim,
            # Absent on configs without a low-rank query projection.
            q_lora_rank=getattr(config, "q_lora_rank", None),
            kv_lora_rank=config.kv_lora_rank,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )

        if (config.n_routed_experts is not None
                and layer_idx >= config.first_k_dense_replace
                and layer_idx % config.moe_layer_freq == 0):
            self.mlp = DeepseekV2MoE(
                config=config,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
                moe_ep_size=moe_ep_size
            )
        else:
            self.mlp = DeepseekV2MLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run attention + FFN.

        Returns:
            ``(hidden_states, residual)``. The residual is returned rather
            than added so the next layer's norm can consume it.
        """
        # Self Attention
        if residual is None:
            # First layer: the raw input becomes the residual stream.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            # The two-argument norm returns (normalized, updated residual),
            # folding the pending residual add into the norm.
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


582
@support_torch_compile
class DeepseekV2Model(nn.Module):
    """Decoder stack: token embeddings, decoder layers and the final norm.

    The stack is pipeline-parallel aware:
      * only the first PP rank owns ``embed_tokens``;
      * only the last PP rank owns ``norm``;
      * other ranks hold ``PPMissingLayer`` placeholders.
    Non-last ranks return ``hidden_states``/``residual`` as
    ``IntermediateTensors``.
    """

    fall_back_to_pt_during_load = False

    def __init__(self,
                 *,
                 vllm_config: VllmConfig,
                 prefix: str = "",
                 moe_ep_size: int = 1):
        super().__init__()

        config = vllm_config.model_config.hf_config
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        if get_pp_group().is_first_rank:
            self.embed_tokens = VocabParallelEmbedding(
                config.vocab_size,
                config.hidden_size,
            )
        else:
            self.embed_tokens = PPMissingLayer()

        # make_layers builds only this PP rank's slice [start_layer, end_layer)
        # and passes each layer its "<prefix>.layers.<idx>" prefix.
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: DeepseekV2DecoderLayer(
                config,
                prefix,
                model_config=model_config,
                cache_config=cache_config,
                quant_config=quant_config,
                moe_ep_size=moe_ep_size
            ),
            prefix=f"{prefix}.layers")

        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings (valid only on the first PP rank)."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this PP rank's layers.

        Returns:
            The final normalized hidden states on the last PP rank, or
            ``IntermediateTensors`` to hand to the next rank otherwise.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            # kv_caches is indexed relative to this rank's first layer.
            hidden_states, residual = layer(positions, hidden_states,
                                            kv_caches[i - self.start_layer],
                                            attn_metadata, residual)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


665
class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
    """DeepseekV2 causal LM: decoder stack, LM head, logits processor, sampler.

    Beyond the standard vLLM wiring, this variant:

    * passes ``parallel_config.moe_ep_size`` to the decoder and uses the
      expert-parallel expert mapping when it is not 1;
    * can re-lay-out weights after loading, controlled by the ``LLAMA_NN``
      and ``AWQ_PAD`` environment variables (see ``load_weights``).
    """

    # Parameter-name fragments re-laid-out by ``ops.trans_w16_gemm`` when
    # LLAMA_NN=1. Keys are regex-escaped so "." matches only a literal dot.
    _LLAMA_NN_WEIGHT_PATTERN = re.compile("|".join(
        re.escape(key) for key in (
            "self_attn.q_proj.weight",
            "self_attn.q_a_proj.weight",
            "self_attn.q_b_proj.weight",
            "self_attn.kv_a_proj_with_mqa.weight",
            "self_attn.kv_b_proj.weight",
            "self_attn.o_proj.weight",
            "mlp.gate_up_proj.weight",
            "mlp.down_proj.weight",
            "mlp.gate.weight",
            "shared_experts.gate_up_proj.weight",
            "shared_experts.down_proj.weight",
            "lm_head.weight",
        )))

    # AWQ ``qweight`` parameter-name fragments repacked by
    # ``ops.convert_s4`` in the non-Triton AWQ path.
    _AWQ_WEIGHT_PATTERN = re.compile("|".join(
        re.escape(key) for key in (
            "self_attn.q_a_proj.qweight",
            "self_attn.q_b_proj.qweight",
            "self_attn.kv_a_proj_with_mqa.qweight",
            "self_attn.kv_b_proj.qweight",
            "self_attn.o_proj.qweight",
            "mlp.gate_up_proj.qweight",
            "mlp.down_proj.qweight",
            "mlp.shared_experts.gate_up_proj.qweight",
            "mlp.shared_experts.down_proj.qweight",
        )))

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        # AWQ does not support the cutlass path yet, so force the Triton AWQ
        # kernels. NOTE: this sets the process-wide ``vllm.envs`` flag for
        # every model, not only AWQ ones.
        envs.VLLM_USE_TRITON_AWQ = True
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config

        self.quant_method = None
        if quant_config is not None:
            self.quant_method = quant_config.get_name()
            # Disable the LLAMA_NN / LM_NN layouts (process-wide) for
            # quantized checkpoints.
            os.environ['LLAMA_NN'] = '0'
            os.environ['LM_NN'] = '0'

        self.config = config
        self.quant_config = quant_config
        self.parallel_config = vllm_config.parallel_config
        self.moe_ep_size = self.parallel_config.moe_ep_size

        self.model = DeepseekV2Model(vllm_config=vllm_config,
                                     prefix=maybe_prefix(prefix, "model"),
                                     moe_ep_size=self.moe_ep_size)
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = get_sampler()
        # NOTE: this instance attribute shadows the
        # ``make_empty_intermediate_tensors`` method defined below.
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)
        # Read after the quantized override above, so LLAMA_NN is always
        # off for quantized checkpoints.
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
        self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the decoder's token embeddings for ``input_ids``."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder stack; see ``DeepseekV2Model.forward``."""
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits via ``lm_head``."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def make_empty_intermediate_tensors(
            self, batch_size: int, dtype: torch.dtype,
            device: torch.device) -> IntermediateTensors:
        """Build zeroed ``hidden_states``/``residual`` PP buffers.

        NOTE(review): ``__init__`` assigns an instance attribute of the same
        name, so on constructed instances that attribute is used instead.
        """
        return IntermediateTensors({
            "hidden_states":
            torch.zeros((batch_size, self.config.hidden_size),
                        dtype=dtype,
                        device=device),
            "residual":
            torch.zeros((batch_size, self.config.hidden_size),
                        dtype=dtype,
                        device=device),
        })

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors into this model's parameters.

        After loading, optional in-place layout conversions run:

        * LLAMA_NN=1 on an unquantized model: ``_convert_llama_nn_weights``.
        * AWQ with ``envs.VLLM_USE_TRITON_AWQ`` false:
          ``_convert_awq_weights``.

        Args:
            weights: ``(checkpoint_name, tensor)`` pairs.

        Returns:
            Names of the parameters that received a tensor.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        if self.moe_ep_size == 1:
            expert_params_mapping = FusedMoE.make_expert_params_mapping(
                ckpt_gate_proj_name="gate_proj",
                ckpt_down_proj_name="down_proj",
                ckpt_up_proj_name="up_proj",
                num_experts=self.config.n_routed_experts)
        else:
            expert_params_mapping = FusedMoE.make_expert_params_mapping_ep(
                ckpt_gate_proj_name="gate_proj",
                ckpt_down_proj_name="down_proj",
                ckpt_up_proj_name="up_proj",
                num_experts=self.config.n_routed_experts,
                moe_ep_size=self.moe_ep_size)

        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()

        # TODO(simon): support nextn predict layers.
        # Weights of the next-n prediction layer (index num_hidden_layers)
        # are skipped. Computed once rather than per checkpoint tensor.
        nextn_layer_prefix = None
        if hasattr(self.config, "num_nextn_predict_layers"
                   ) and self.config.num_nextn_predict_layers > 0:
            assert self.config.num_nextn_predict_layers == 1
            nextn_layer_prefix = (
                f"model.layers.{self.config.num_hidden_layers}")

        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if (nextn_layer_prefix is not None
                    and name.startswith(nextn_layer_prefix)):
                continue

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if (("mlp.experts." in name) and name not in params_dict):
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)

                    if is_pp_missing_parameter(name, self):
                        continue

                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param,
                                  loaded_weight,
                                  name,
                                  shard_id=shard_id,
                                  expert_id=expert_id)
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue

                    # Remapping the name of FP8 kv-scale.
                    name = maybe_remap_kv_scale_name(name, params_dict)
                    if name is None:
                        continue

                    if is_pp_missing_parameter(name, self):
                        continue

                    # Skip loading extra expert weights for ep moe mode
                    if name not in params_dict:
                        continue

                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)

        if self.use_llama_nn and self.quant_method is None:
            self._convert_llama_nn_weights(params_dict, loaded_params)

        # The TN layout is not supported yet.
        # NOTE(review): __init__ forces envs.VLLM_USE_TRITON_AWQ = True, so
        # this conversion only runs if that flag is reset afterwards.
        if self.quant_method == "awq" and not envs.VLLM_USE_TRITON_AWQ:
            self._convert_awq_weights(params_dict, loaded_params)

        return loaded_params

    def _convert_llama_nn_weights(self, params_dict: Dict[str, nn.Parameter],
                                  loaded_params: Set[str]) -> None:
        """Re-lay-out matching dense weights in place for the LLAMA_NN path.

        For each loaded parameter whose name matches
        ``_LLAMA_NN_WEIGHT_PATTERN``, the data is passed through
        ``ops.trans_w16_gemm`` into a scratch tensor. The result is copied
        back, then reshaped to ``(shape[1], -1)``; for a 2-D weight that is
        ``[cols, rows]``.
        """
        for layername in loaded_params:
            if not self._LLAMA_NN_WEIGHT_PATTERN.search(layername):
                continue
            weight = params_dict[layername]
            _weight = torch.zeros_like(weight.data)
            ori_shape = _weight.shape

            ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0],
                               _weight.shape[1])
            weight.data.copy_(_weight)

            weight.data = weight.data.reshape(ori_shape[1], -1)

    def _convert_awq_weights(self, params_dict: Dict[str, nn.Parameter],
                             loaded_params: Set[str]) -> None:
        """Repack matching AWQ int4 weights in place for non-Triton kernels.

        Each matching ``qweight`` is processed as follows:

        * ``ops.convert_s4`` repacks it together with its ``qzeros`` and
          ``scales``.
        * The permuted zeros/scales are written to the sibling
          ``zeros_and_scales`` parameter.
        * Both tensors are reshaped so ``n`` (``scales.shape[1]``) is the
          leading dimension.

        With AWQ_PAD=1, layers whose ``qweight.shape[0]`` (K) is a multiple
        of 4096 get zero padding appended along dim 1:

        * ``pad_group`` int32 columns on ``zeros_and_scales``;
        * ``group_size // 4`` int32 columns on ``qweight``.
        """
        group_size = self.quant_config.group_size
        # Number of zero columns appended to zeros_and_scales when padding.
        pad_group = 2

        for layername in loaded_params:
            if not self._AWQ_WEIGHT_PATTERN.search(layername):
                continue
            qweight = params_dict[layername]
            qzeros = params_dict[layername.replace("qweight", "qzeros")]
            scales = params_dict[layername.replace("qweight", "scales")]
            zeros_and_scales = params_dict[layername.replace(
                "qweight", "zeros_and_scales")]

            dim_n = scales.data.shape[1]
            dim_k = qweight.data.shape[0]

            _qw, _sz = ops.convert_s4(qweight, qzeros, scales,
                                      int(group_size))
            sz = ops.sz_permute(_sz).reshape(-1, dim_n)

            zeros_and_scales.data.copy_(sz)
            qweight.data.copy_(_qw)

            # Shape labels below come from the original author. These are
            # reshapes of already-repacked data, not transposes.
            # [k/group_size, n] -> [n, k/group_size]
            zeros_and_scales.data = zeros_and_scales.data.reshape(dim_n, -1)
            # [k, n/8] -> [n, k/8]
            qweight.data = qweight.data.reshape(dim_n, -1)

            if dim_k % 4096 == 0 and self.use_awq_pad:
                # Allocate padding on the parameters' own devices so that
                # torch.cat does not mix devices.
                zeros_and_scales_pad = torch.zeros(
                    dim_n,
                    pad_group,
                    dtype=torch.int32,
                    device=zeros_and_scales.data.device)
                zeros_and_scales.data = torch.cat(
                    (zeros_and_scales.data, zeros_and_scales_pad),
                    dim=1).contiguous()
                qweight_pad = torch.zeros(dim_n,
                                          int(group_size // 4),
                                          dtype=torch.int32,
                                          device=qweight.data.device)
                qweight.data = torch.cat((qweight.data, qweight_pad),
                                         dim=1).contiguous()


class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
    """DeepseekV3 causal LM; reuses the DeepseekV2 implementation unchanged."""