deepseek.py 19.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Deepseek model."""
26
from collections.abc import Iterable
27
from itertools import islice
28
from typing import Any, Optional, Union
29
30
31
32
33

import torch
from torch import nn
from transformers import PretrainedConfig

34
from vllm.attention import Attention
35
from vllm.config import CacheConfig, VllmConfig
36
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
37
38
                              get_tensor_model_parallel_world_size,
                              tensor_model_parallel_all_reduce)
39
from vllm.model_executor.layers.activation import SiluAndMul
40
from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
41
from vllm.model_executor.layers.layernorm import RMSNorm
42
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
43
                                               QKVParallelLinear,
44
                                               ReplicatedLinear,
45
                                               RowParallelLinear)
46
from vllm.model_executor.layers.logits_processor import LogitsProcessor
47
from vllm.model_executor.layers.quantization import QuantizationConfig
48
from vllm.model_executor.layers.rotary_embedding import get_rope
49
from vllm.model_executor.layers.vocab_parallel_embedding import (
50
    ParallelLMHead, VocabParallelEmbedding)
51
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
52
from vllm.sequence import IntermediateTensors
53

54
from .interfaces import SupportsLoRA, SupportsPP
55
56
from .utils import (AutoWeightsLoader, extract_layer_index,
                    is_pp_missing_parameter,
57
58
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
59

60
61
62
63
64
65
66
67

class DeepseekMLP(nn.Module):
    """Gated SiLU feed-forward block: fused gate/up projection, SiLU-and-mul,
    then a down projection.

    Args:
        hidden_size: Input and output feature size.
        intermediate_size: Output size of each of the gate and up projections.
        hidden_act: Activation name; only ``"silu"`` is accepted.
        quant_config: Optional quantization config for both projections.
        reduce_results: Forwarded to the down projection. ``DeepseekMoE``
            passes ``False`` and performs the all-reduce itself.
        prefix: Module-name prefix; the projections are registered as
            ``{prefix}.gate_up_proj`` and ``{prefix}.down_proj``.

    Raises:
        ValueError: If ``hidden_act`` is not ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # Validate before allocating any (possibly large) weights.
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        # Gate and up projections share one matmul producing two
        # intermediate_size outputs, consumed together by SiluAndMul.
        # Forwarding the prefix lets name-based quantization configs
        # recognise these layers.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj")
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           quant_config=quant_config,
                                           reduce_results=reduce_results,
                                           prefix=f"{prefix}.down_proj")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply gate/up projection, SiLU-and-mul, and down projection."""
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class DeepseekMoE(nn.Module):
    """Mixture-of-experts feed-forward block with optional shared experts.

    The routed experts are ``DeepseekMLP`` modules whose weights are packed
    into stacked tensors (``self.w1`` / ``self.w2``) that ``fused_experts``
    consumes directly.

    Args:
        config: HF config. Reads ``n_routed_experts``,
            ``num_experts_per_tok``, ``hidden_size``,
            ``moe_intermediate_size``, ``hidden_act``, ``n_shared_experts``
            and ``norm_topk_prob``.
        quant_config: Optional quantization config for the expert MLPs.
        prefix: Module-name prefix forwarded to the experts, router and
            shared experts.

    Raises:
        ValueError: If the tensor-parallel size exceeds the number of
            routed experts.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.rank = get_tensor_model_parallel_rank()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.n_routed_experts = config.n_routed_experts
        self.top_k = config.num_experts_per_tok
        if self.tp_size > self.n_routed_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {self.n_routed_experts}.")

        # reduce_results=False: routed and shared outputs are summed first and
        # all-reduced once in forward().
        self.experts = nn.ModuleList([
            DeepseekMLP(hidden_size=config.hidden_size,
                        intermediate_size=config.moe_intermediate_size,
                        hidden_act=config.hidden_act,
                        quant_config=quant_config,
                        reduce_results=False,
                        prefix=f"{prefix}.experts.{idx}")
            for idx in range(self.n_routed_experts)
        ])
        self.pack_params()

        # The router is always left unquantized.
        self.gate = ReplicatedLinear(config.hidden_size,
                                     self.n_routed_experts,
                                     bias=False,
                                     quant_config=None,
                                     prefix=f"{prefix}.gate")

        if config.n_shared_experts is not None:
            intermediate_size = (config.moe_intermediate_size *
                                 config.n_shared_experts)
            self.shared_experts = DeepseekMLP(
                hidden_size=config.hidden_size,
                intermediate_size=intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                reduce_results=False,
                prefix=f"{prefix}.shared_experts",
            )

    def pack_params(self):
        """Pack all expert weights into two contiguous stacked tensors.

        Afterwards ``self.w1`` has shape
        ``(n_routed_experts, *gate_up_proj.weight.shape)`` and ``self.w2``
        has shape ``(n_routed_experts, *down_proj.weight.shape)``.
        """
        self.w1 = self._pack(
            [expert.gate_up_proj.weight for expert in self.experts])
        self.w2 = self._pack(
            [expert.down_proj.weight for expert in self.experts])

    @staticmethod
    def _pack(params):
        """Flatten ``params`` into one buffer and return it stacked per expert.

        Each parameter's ``.data`` is rebound to its slice of the buffer. The
        per-expert parameters and the returned stacked tensor therefore share
        storage.
        """
        flat = torch._utils._flatten_dense_tensors(params)
        slices = torch._utils._unflatten_dense_tensors(flat, params)
        for data, param in zip(slices, params):
            param.data = data
        return flat.view(len(params), *slices[0].shape)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route each token to its top-k experts and add the shared experts.

        Args:
            hidden_states: 2-D ``(num_tokens, hidden_dim)`` input.
                ``fused_experts`` is called with ``inplace=True``, so this
                buffer may be overwritten.

        Returns:
            Tensor of shape ``(num_tokens, hidden_dim)``, all-reduced across
            the tensor-parallel group.
        """
        num_tokens, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        # The shared experts read hidden_states before the in-place fused
        # kernel can overwrite it.
        shared_output = None
        if self.config.n_shared_experts is not None:
            shared_output = self.shared_experts(hidden_states)
        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)

        topk_weights, topk_ids, _ = fused_topk(
            hidden_states,
            router_logits,
            self.top_k,
            renormalize=self.config.norm_topk_prob)

        final_hidden_states = fused_experts(hidden_states,
                                            self.w1,
                                            self.w2,
                                            topk_weights,
                                            topk_ids,
                                            inplace=True)

        if shared_output is not None:
            final_hidden_states = final_hidden_states + shared_output
        # Experts were built with reduce_results=False; reduce once here.
        final_hidden_states = tensor_model_parallel_all_reduce(
            final_hidden_states)

        return final_hidden_states.view(num_tokens, hidden_dim)


class DeepseekAttention(nn.Module):
    """Self-attention with rotary position embeddings and tensor-parallel heads.

    Query heads are split evenly across tensor-parallel ranks. KV heads are
    partitioned when there are at least as many KV heads as ranks. Otherwise
    they are replicated, and each rank keeps a single KV head.

    Args:
        hidden_size: Model hidden size; ``head_dim`` is
            ``hidden_size // num_heads``.
        num_heads: Total number of query heads.
        num_kv_heads: Total number of key/value heads.
        rope_theta: Rotary embedding base.
        rope_scaling: Optional rotary scaling config forwarded to ``get_rope``.
        max_position_embeddings: Maximum position handled by the rotary
            embedding.
        cache_config: Optional KV-cache config forwarded to ``Attention``.
        quant_config: Optional quantization config for the projections and
            attention.
        prefix: Module-name prefix forwarded to the sub-modules.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than or equal to TP size, so we
            # partition the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        # Forward the prefix so name-based quantization configs can match
        # these projections.
        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project to Q/K/V, apply rotary embeddings, attend, and project out."""
        qkv, _ = self.qkv_proj(hidden_states)
        # Per-rank split: q has num_heads heads, k and v num_kv_heads each.
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


class DeepseekDecoderLayer(nn.Module):
    """One decoder layer: pre-norm self-attention followed by a pre-norm MLP.

    The MLP is a ``DeepseekMoE`` when all of the following hold:
    ``config.n_routed_experts`` is set, the layer index is at least
    ``config.first_k_dense_replace``, and the index is a multiple of
    ``config.moe_layer_freq``. Otherwise it is a dense ``DeepseekMLP``.

    Args:
        config: HF model config.
        cache_config: Optional KV-cache config for the attention layer.
        quant_config: Optional quantization config.
        prefix: Module-name prefix; the layer index is parsed from it.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        layer_idx = extract_layer_index(prefix)
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.self_attn = DeepseekAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        if (config.n_routed_experts is not None
                and layer_idx >= config.first_k_dense_replace
                and layer_idx % config.moe_layer_freq == 0):
            self.mlp = DeepseekMoE(config=config,
                                   quant_config=quant_config,
                                   prefix=f"{prefix}.mlp")
        else:
            self.mlp = DeepseekMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run attention and MLP sub-blocks on the residual stream.

        Args:
            positions: Token positions for the rotary embedding.
            hidden_states: Output of the previous layer, or the embeddings.
            residual: Running residual, or ``None`` when no residual exists
                yet. In that case ``hidden_states`` itself becomes the
                residual.

        Returns:
            ``(hidden_states, residual)`` for the next layer.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            # The two-argument norm call returns a (normed, residual) pair.
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


class DeepseekModel(nn.Module):
    """Deepseek decoder stack: token embedding, decoder layers, final norm.

    Supports pipeline parallelism. This rank runs layers
    ``[start_layer, end_layer)``, and every rank except the last returns
    ``IntermediateTensors`` instead of normalized hidden states.
    """

    fall_back_to_pt_during_load = False

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: DeepseekDecoderLayer(
                config, cache_config, quant_config=quant_config, prefix=prefix
            ),
            prefix=f"{prefix}.layers")
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed ``input_ids`` with the vocab-parallel embedding table."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this pipeline stage's layers.

        Args:
            input_ids: Token ids; used only on the first rank when
                ``inputs_embeds`` is not given.
            positions: Token positions for the rotary embeddings.
            intermediate_tensors: ``hidden_states`` / ``residual`` produced
                by the previous pipeline stage; required on non-first ranks.
            inputs_embeds: Optional precomputed embeddings (first rank only).

        Returns:
            Final normalized hidden states on the last rank; otherwise an
            ``IntermediateTensors`` for the next stage.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            # Fail with a clear message instead of "'NoneType' object is not
            # subscriptable".
            assert intermediate_tensors is not None, (
                "intermediate_tensors is required on non-first PP ranks")
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states, residual = layer(positions, hidden_states, residual)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Unfused ``q/k/v_proj`` and ``gate/up_proj`` checkpoint tensors are
        routed into the fused ``qkv_proj`` / ``gate_up_proj`` parameters with
        the matching shard id. Rotary ``inv_freq`` buffers, surplus GPTQ
        biases, experts not present on this worker, and parameters of layers
        missing on this pipeline rank are skipped.

        Returns:
            Names of the parameters that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        def should_skip(name: str) -> bool:
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                return True
            # Skip experts that are not assigned to this worker.
            if (("mlp.experts." in name or "mlp.shared_experts." in name)
                    and name not in params_dict):
                return True
            return is_pp_missing_parameter(name, self)

        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            # Rewrite at most once. E.g. "gate_proj" -> "gate_up_proj" also
            # contains "up_proj" and must not be rewritten a second time.
            shard_id = None
            for param_name, weight_name, mapped_shard_id in (
                    stacked_params_mapping):
                if weight_name in name:
                    name = name.replace(weight_name, param_name)
                    shard_id = mapped_shard_id
                    break
            if should_skip(name):
                continue
            param = params_dict[name]
            if shard_id is None:
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            else:
                # Fused parameters always carry a shard-aware loader.
                param.weight_loader(param, loaded_weight, shard_id)
            loaded_params.add(name)
        return loaded_params

446

447
448
449
450
451
class DeepseekForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """Deepseek causal language model: ``DeepseekModel`` plus an LM head."""

    # Fused parameter name -> checkpoint sub-modules it is assembled from.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        self.model = DeepseekModel(vllm_config=vllm_config,
                                   prefix=maybe_prefix(prefix, "model"))
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        if self.config.tie_word_embeddings:
            # The LM head reuses the input embedding tensor.
            self.lm_head.weight = self.model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed ``input_ids`` using the underlying model's embedding."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Delegate to ``DeepseekModel.forward``; see it for the contract."""
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits via the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors and return the names that were loaded.

        With tied embeddings, ``lm_head.weight`` is the embedding tensor
        itself. Any ``lm_head.*`` checkpoint entries are then skipped so they
        cannot overwrite the tied embedding.
        """
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."]
                           if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(weights)