deepseek.py 18.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Deepseek model."""
26

27
from collections.abc import Iterable
28
from itertools import islice
29
from typing import Any
30
31
32
33
34

import torch
from torch import nn
from transformers import PretrainedConfig

35
from vllm.attention import Attention
36
from vllm.config import CacheConfig, VllmConfig
37
38
39
40
41
42
from vllm.distributed import (
    get_pp_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_reduce,
)
43
from vllm.model_executor.layers.activation import SiluAndMul
44
from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
45
from vllm.model_executor.layers.layernorm import RMSNorm
46
47
48
49
50
51
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
52
from vllm.model_executor.layers.logits_processor import LogitsProcessor
53
from vllm.model_executor.layers.quantization import QuantizationConfig
54
from vllm.model_executor.layers.rotary_embedding import get_rope
55
from vllm.model_executor.layers.vocab_parallel_embedding import (
56
57
58
    ParallelLMHead,
    VocabParallelEmbedding,
)
59
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
60
from vllm.sequence import IntermediateTensors
61

62
from .interfaces import SupportsLoRA, SupportsPP
63
64
65
66
67
68
69
70
from .utils import (
    AutoWeightsLoader,
    extract_layer_index,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
71

72
73
74
75
76
77
78

class DeepseekMLP(nn.Module):
    """Gated SiLU feed-forward block: gate/up projection, SiLU-and-mul, down.

    Used as the dense MLP of non-MoE decoder layers and as each routed /
    shared expert inside ``DeepseekMoE``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        reduce_results: bool = True,
        prefix: str = "",
    ) -> None:
        """Build the projections.

        Args:
            hidden_size: Input and output width of the block.
            intermediate_size: Width of each of the gate and up projections
                (they are fused into one ``MergedColumnParallelLinear``).
            hidden_act: Activation name; only ``"silu"`` is supported.
            quant_config: Optional quantization config for both projections.
            reduce_results: Forwarded to the down projection.
                ``DeepseekMoE`` passes ``False`` and all-reduces the combined
                expert output itself.
            prefix: Fully qualified module name. It is forwarded to the
                sub-layers as ``{prefix}.gate_up_proj`` / ``{prefix}.down_proj``.

        Raises:
            ValueError: If ``hidden_act`` is not ``"silu"``.
        """
        super().__init__()
        # Validate before allocating any (possibly large) projection weights.
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=f"{prefix}.down_proj",
        )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply gate/up projection, SiLU-and-mul, then the down projection."""
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class DeepseekMoE(nn.Module):
    """Mixture-of-experts feed-forward block with optional shared experts.

    Each routed expert is a ``DeepseekMLP``. Their weights are packed into two
    stacked tensors (``w1`` for gate/up, ``w2`` for down) consumed by
    ``fused_experts``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """Build the experts, router gate and optional shared experts.

        Args:
            config: HF config providing ``n_routed_experts``,
                ``num_experts_per_tok``, ``moe_intermediate_size``,
                ``hidden_size``, ``hidden_act``, ``n_shared_experts`` and
                ``norm_topk_prob``.
            quant_config: Optional quantization config for expert projections
                (the router gate is always unquantized).
            prefix: Fully qualified module name, used to name sub-layers.

        Raises:
            ValueError: If the tensor-parallel size exceeds the expert count.
        """
        super().__init__()
        self.config = config
        self.rank = get_tensor_model_parallel_rank()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.n_routed_experts = config.n_routed_experts
        self.top_k = config.num_experts_per_tok
        if self.tp_size > self.n_routed_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {self.n_routed_experts}."
            )

        self.experts = nn.ModuleList(
            [
                DeepseekMLP(
                    hidden_size=config.hidden_size,
                    intermediate_size=config.moe_intermediate_size,
                    hidden_act=config.hidden_act,
                    quant_config=quant_config,
                    # The combined MoE output is all-reduced once in forward().
                    reduce_results=False,
                    prefix=f"{prefix}.experts.{idx}",
                )
                for idx in range(self.n_routed_experts)
            ]
        )
        self.pack_params()

        self.gate = ReplicatedLinear(
            config.hidden_size,
            self.n_routed_experts,
            bias=False,
            quant_config=None,
            prefix=f"{prefix}.gate",
        )

        if config.n_shared_experts is not None:
            intermediate_size = config.moe_intermediate_size * config.n_shared_experts
            self.shared_experts = DeepseekMLP(
                hidden_size=config.hidden_size,
                intermediate_size=intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                reduce_results=False,
                prefix=f"{prefix}.shared_experts",
            )

    def pack_params(self):
        """Pack all experts' weights into two contiguous stacked tensors.

        Afterwards each expert's ``gate_up_proj.weight`` / ``down_proj.weight``
        points at a slice of ``self.w1`` / ``self.w2``. In-place writes to the
        per-expert parameters (e.g. during weight loading) therefore land in
        the stacked tensors used by ``fused_experts``.
        """
        self.w1 = self._pack([expert.gate_up_proj.weight for expert in self.experts])
        self.w2 = self._pack([expert.down_proj.weight for expert in self.experts])

    @staticmethod
    def _pack(params: list[torch.Tensor]) -> torch.Tensor:
        """Flatten ``params`` into one buffer and re-point each at its slice.

        Returns the buffer viewed as ``(len(params), *params[0].shape)``.
        """
        flat = torch._utils._flatten_dense_tensors(params)
        slices = torch._utils._unflatten_dense_tensors(flat, params)
        for data, param in zip(slices, params):
            param.data = data
        return flat.view(len(params), *slices[0].shape)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route each token to its top-k experts and combine the results.

        Args:
            hidden_states: 2-D tensor ``(num_tokens, hidden_dim)``.

        Returns:
            Tensor of the same shape, all-reduced across tensor-parallel ranks.
        """
        num_tokens, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        # Shared experts must consume the input before fused_experts runs.
        # That call uses inplace=True, presumably overwriting hidden_states.
        shared_output = None
        if self.config.n_shared_experts is not None:
            shared_output = self.shared_experts(hidden_states)
        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        topk_weights, topk_ids, _ = fused_topk(
            hidden_states,
            router_logits,
            self.top_k,
            renormalize=self.config.norm_topk_prob,
        )
        final_hidden_states = fused_experts(
            hidden_states, self.w1, self.w2, topk_weights, topk_ids, inplace=True
        )
        if shared_output is not None:
            final_hidden_states = final_hidden_states + shared_output
        # Experts were built with reduce_results=False, so reduce once here.
        final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
        return final_hidden_states.view(num_tokens, hidden_dim)
197
198
199
200
201
202
203
204
205


class DeepseekAttention(nn.Module):
    """Self-attention with rotary position embeddings and tensor parallelism.

    Query heads are split evenly across tensor-parallel ranks. KV heads are
    partitioned when there are at least as many as ranks, else replicated.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: dict[str, Any] | None = None,
        max_position_embeddings: int = 8192,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Build the QKV/output projections, rotary embedding and attention op.

        Args:
            hidden_size: Model width. ``head_dim`` is ``hidden_size // num_heads``.
            num_heads: Total number of query heads (across all TP ranks).
            num_kv_heads: Total number of key/value heads.
            rope_theta: Rotary embedding base.
            rope_scaling: Optional rotary scaling config passed to ``get_rope``.
            max_position_embeddings: Maximum position passed to ``get_rope``.
            cache_config: KV-cache config for the attention op.
            quant_config: Optional quantization config.
            prefix: Fully qualified module name, used to name sub-layers.
        """
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        # Every rank keeps at least one KV head (replicated case).
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project to Q/K/V, apply rotary embeddings, attend, then project out."""
        qkv, _ = self.qkv_proj(hidden_states)
        # Split this rank's fused projection into its local Q, K and V parts.
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


class DeepseekDecoderLayer(nn.Module):
    """One decoder block: pre-norm self-attention, then a pre-norm MLP or MoE."""

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Build attention, the feed-forward block and both RMSNorms.

        Args:
            config: HF model config.
            cache_config: KV-cache config forwarded to attention.
            quant_config: Optional quantization config.
            prefix: Fully qualified module name. The layer index is parsed
                from it via ``extract_layer_index``.
        """
        super().__init__()
        layer_idx = extract_layer_index(prefix)
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        moe_layer_freq = getattr(config, "moe_layer_freq", 1)
        self.self_attn = DeepseekAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        # Use MoE only when routed experts exist, the layer is past the first
        # `first_k_dense_replace` dense layers, and it falls on the
        # `moe_layer_freq` stride. Otherwise use a dense MLP.
        if (
            config.n_routed_experts is not None
            and layer_idx >= config.first_k_dense_replace
            and layer_idx % moe_layer_freq == 0
        ):
            self.mlp = DeepseekMoE(
                config=config, quant_config=quant_config, prefix=f"{prefix}.mlp"
            )
        else:
            self.mlp = DeepseekMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the block.

        Args:
            positions: Token positions for the rotary embedding.
            hidden_states: Output of the previous layer (or the embeddings).
            residual: Running residual, or ``None`` for the first layer.
                In that case ``hidden_states`` itself becomes the residual.

        Returns:
            ``(hidden_states, residual)``. The last residual add is not
            applied here; the caller passes both tensors to the next norm.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


class DeepseekModel(nn.Module):
    """Deepseek decoder stack: token embedding, decoder layers, final norm.

    Under pipeline parallelism each rank runs only layers
    ``[start_layer, end_layer)``. It exchanges ``hidden_states`` and
    ``residual`` with neighbouring ranks through ``IntermediateTensors``.
    """

    fall_back_to_pt_during_load = False

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build embeddings, this rank's decoder layers and the final norm."""
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            prefix=f"{prefix}.embed_tokens",
        )
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: DeepseekDecoderLayer(
                config, cache_config, quant_config=quant_config, prefix=prefix
            ),
            prefix=f"{prefix}.layers",
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed token ids with the vocab-parallel embedding table."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this rank's slice of the decoder stack.

        The first PP rank starts from ``inputs_embeds`` (if given) or embeds
        ``input_ids``. Later ranks resume from ``intermediate_tensors``.
        Non-final ranks return ``IntermediateTensors``. The final rank returns
        the normed hidden states.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states, residual = layer(positions, hidden_states, residual)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def _should_skip_weight(
        self, name: str, params_dict: dict[str, nn.Parameter]
    ) -> bool:
        """Return True for checkpoint entries this rank must not load."""
        # Skip loading extra bias for GPTQ models.
        if name.endswith(".bias") and name not in params_dict:
            return True
        # Skip experts that are not assigned to this worker.
        if (
            "mlp.experts." in name or "mlp.shared_experts." in name
        ) and name not in params_dict:
            return True
        # Skip layers owned by another pipeline-parallel rank.
        return is_pp_missing_parameter(name, self)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors, merging separate q/k/v and gate/up shards.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            Names of the parameters that were actually loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # A shard match is final. Stop scanning even when the weight is
                # skipped, so later entries never see the rewritten name
                # (e.g. "up_proj" inside "gate_up_proj").
                if self._should_skip_weight(name, params_dict):
                    break
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                loaded_params.add(name)
                break
            else:
                # Not a stacked shard: load the tensor as-is.
                if self._should_skip_weight(name, params_dict):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
                loaded_params.add(name)
        return loaded_params

457

458
459
460
461
462
class DeepseekForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """Deepseek causal language model: decoder stack plus LM head."""

    # Fused module name -> checkpoint sub-module names it is built from.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the decoder stack, LM head (optionally tied) and logits processor."""
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        self.model = DeepseekModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        if self.config.tie_word_embeddings:
            # The LM head shares the embedding matrix.
            self.lm_head.weight = self.model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed token ids via the underlying model."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the decoder stack. Logits are computed separately by ``compute_logits``."""
        return self.model(input_ids, positions, intermediate_tensors, inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project final hidden states to vocabulary logits."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights.

        When embeddings are tied, ``lm_head.weight`` aliases
        ``embed_tokens.weight``. Any ``lm_head.*`` checkpoint entry is
        therefore skipped so it cannot overwrite the shared embedding.
        """
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(weights)