# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Deepseek model."""
26
27
from collections.abc import Iterable
from typing import Any, Optional, Union
28
29
30
31
32

import torch
from torch import nn
from transformers import PretrainedConfig

33
from vllm.attention import Attention
34
from vllm.config import CacheConfig, VllmConfig
35
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
36
37
                              get_tensor_model_parallel_world_size,
                              tensor_model_parallel_all_reduce)
38
39
40
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.layernorm import RMSNorm
41
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
42
                                               QKVParallelLinear,
43
                                               ReplicatedLinear,
44
                                               RowParallelLinear)
45
from vllm.model_executor.layers.logits_processor import LogitsProcessor
46
from vllm.model_executor.layers.quantization import QuantizationConfig
47
from vllm.model_executor.layers.rotary_embedding import get_rope
48
from vllm.model_executor.layers.vocab_parallel_embedding import (
49
    ParallelLMHead, VocabParallelEmbedding)
50
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
51
from vllm.model_executor.sampling_metadata import SamplingMetadata
52
from vllm.sequence import IntermediateTensors
53

54
from .interfaces import SupportsPP
55
56
from .utils import (AutoWeightsLoader, extract_layer_index,
                    is_pp_missing_parameter,
57
58
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
59

60
61
62
63
64
65
66
67

class DeepseekMLP(nn.Module):
    """Gated SiLU feed-forward block sharded over tensor-parallel ranks.

    Computes ``down_proj(SiluAndMul(gate_up_proj(x)))`` where ``gate_up_proj``
    is a merged gate/up projection.

    Args:
        hidden_size: Input and output feature size.
        intermediate_size: Output size of each of the gate and up
            projections.
        hidden_act: Activation name; only ``"silu"`` is supported.
        quant_config: Optional quantization config forwarded to both linear
            layers.
        reduce_results: Forwarded to ``down_proj``. ``DeepseekMoE`` passes
            ``False`` for its experts and performs one all-reduce itself.
        prefix: Fully-qualified module name. It is forwarded to the linear
            layers so that quantization configs keyed on layer names can
            identify them.

    Raises:
        ValueError: If ``hidden_act`` is not ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # Validate before allocating any (possibly large) weights.
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj")
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           quant_config=quant_config,
                                           reduce_results=reduce_results,
                                           prefix=f"{prefix}.down_proj")
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated MLP to ``x`` and return the projected output."""
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class DeepseekMoE(nn.Module):
    """Mixture-of-experts feed-forward block with optional shared experts.

    Routed experts are ``DeepseekMLP`` instances. Their weights are repacked
    into stacked tensors ``w1`` (gate/up) and ``w2`` (down), which are passed
    to ``fused_moe``. When ``config.n_shared_experts`` is set, a dense shared
    MLP runs on every token. Its output is added to the routed output before
    a single tensor-parallel all-reduce.

    Args:
        config: HF config. Reads ``n_routed_experts``,
            ``num_experts_per_tok``, ``hidden_size``,
            ``moe_intermediate_size``, ``hidden_act`` and
            ``n_shared_experts``, plus ``norm_topk_prob`` in ``forward``.
        quant_config: Optional quantization config for the expert MLPs. The
            router gate is always built unquantized.
        prefix: Fully-qualified module name, forwarded to all submodules.

    Raises:
        ValueError: If the tensor-parallel size exceeds the number of routed
            experts.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.rank = get_tensor_model_parallel_rank()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.n_routed_experts = config.n_routed_experts
        self.top_k = config.num_experts_per_tok
        if self.tp_size > self.n_routed_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {self.n_routed_experts}.")

        # reduce_results=False: the all-reduce happens once in forward(),
        # after routed and shared outputs are summed.
        self.experts = nn.ModuleList([
            DeepseekMLP(hidden_size=config.hidden_size,
                        intermediate_size=config.moe_intermediate_size,
                        hidden_act=config.hidden_act,
                        quant_config=quant_config,
                        reduce_results=False,
                        prefix=f"{prefix}.experts.{idx}")
            for idx in range(self.n_routed_experts)
        ])
        self.pack_params()

        self.gate = ReplicatedLinear(config.hidden_size,
                                     self.n_routed_experts,
                                     bias=False,
                                     quant_config=None,
                                     prefix=f"{prefix}.gate")

        if config.n_shared_experts is not None:
            intermediate_size = (config.moe_intermediate_size *
                                 config.n_shared_experts)
            self.shared_experts = DeepseekMLP(
                hidden_size=config.hidden_size,
                intermediate_size=intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                reduce_results=False,
                prefix=f"{prefix}.shared_experts",
            )

    def pack_params(self):
        """Repack the per-expert weights into the stacked ``w1`` and ``w2``.

        Each expert's ``weight`` parameter is rebound to a view into one
        contiguous buffer. Weights later loaded into the per-expert
        parameters therefore land directly in ``self.w1`` / ``self.w2``.
        Those tensors have shape ``(n_routed_experts, *per_expert_shape)``.
        """
        self.w1 = self._pack(
            [expert.gate_up_proj.weight for expert in self.experts])
        self.w2 = self._pack(
            [expert.down_proj.weight for expert in self.experts])

    @staticmethod
    def _pack(params: list[torch.Tensor]) -> torch.Tensor:
        """Flatten ``params`` into one buffer and rebind each to a view of it.

        Returns the buffer viewed as ``(len(params), *params[0].shape)``.
        """
        flat = torch._utils._flatten_dense_tensors(params)
        views = torch._utils._unflatten_dense_tensors(flat, params)
        for view, param in zip(views, params):
            param.data = view
        return flat.view(len(params), *views[0].shape)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens through the experts and return the all-reduced sum.

        ``hidden_states`` must be 2-D ``(num_tokens, hidden_dim)``; the
        shape is unpacked below. The output has the same shape.
        """
        num_tokens, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        has_shared_experts = self.config.n_shared_experts is not None
        # Compute the shared-expert output before fused_moe. It is called
        # with inplace=True, which (per vLLM's fused_moe) may overwrite
        # hidden_states.
        if has_shared_experts:
            shared_output = self.shared_experts(hidden_states)
        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        final_hidden_states = fused_moe(hidden_states,
                                        self.w1,
                                        self.w2,
                                        router_logits,
                                        self.top_k,
                                        renormalize=self.config.norm_topk_prob,
                                        inplace=True)

        if has_shared_experts:
            final_hidden_states = final_hidden_states + shared_output
        final_hidden_states = tensor_model_parallel_all_reduce(
            final_hidden_states)

        return final_hidden_states.view(num_tokens, hidden_dim)
179
180
181
182
183
184
185
186
187
188


class DeepseekAttention(nn.Module):
    """Multi-head self-attention with rotary embeddings, sharded over
    tensor-parallel ranks.

    Query heads are split evenly across ranks. KV heads are split across
    ranks when there are at least ``tp_size`` of them. Otherwise they are
    replicated, and each rank holds one KV head.

    Args:
        hidden_size: Model hidden size. ``head_dim`` is
            ``hidden_size // num_heads``.
        num_heads: Total number of query heads (must divide by TP size).
        num_kv_heads: Total number of key/value heads.
        rope_theta: Rotary embedding base.
        rope_scaling: Optional rotary scaling config forwarded to
            ``get_rope``.
        max_position_embeddings: Maximum position for the rotary embedding.
        cache_config: KV-cache config forwarded to ``Attention``.
        quant_config: Optional quantization config for the projections and
            attention.
        prefix: Fully-qualified module name, forwarded to all submodules.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        # Per-rank widths of the q and k/v slices of the fused qkv output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project to q/k/v, apply rotary embeddings, attend, project out.

        ``hidden_states`` is split along its last dimension after
        ``qkv_proj``. Presumably it is ``(num_tokens, hidden_size)`` —
        confirm against callers.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


class DeepseekDecoderLayer(nn.Module):
    """One Deepseek transformer layer: RMSNorm + self-attention, then
    RMSNorm + feed-forward.

    The feed-forward is a ``DeepseekMoE`` when all of the following hold:
    ``config.n_routed_experts`` is set; the layer index (parsed from
    ``prefix``) is at least ``config.first_k_dense_replace``; and that index
    is a multiple of ``config.moe_layer_freq``. Otherwise it is a dense
    ``DeepseekMLP``.

    Args:
        config: HF model config.
        cache_config: KV-cache config forwarded to the attention module.
        quant_config: Optional quantization config for all submodules.
        prefix: Fully-qualified module name; also the source of the layer
            index.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        layer_idx = extract_layer_index(prefix)
        self.hidden_size = config.hidden_size
        # Optional config fields fall back to these defaults when absent.
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.self_attn = DeepseekAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        if (config.n_routed_experts is not None
                and layer_idx >= config.first_k_dense_replace
                and layer_idx % config.moe_layer_freq == 0):
            self.mlp = DeepseekMoE(config=config,
                                   quant_config=quant_config,
                                   prefix=f"{prefix}.mlp")
        else:
            self.mlp = DeepseekMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run attention and the feed-forward on ``hidden_states``.

        Args:
            positions: Token positions for the rotary embedding.
            hidden_states: Output of the previous layer (or embeddings).
            residual: Residual stream carried from the previous layer, or
                ``None`` for the first layer.

        Returns:
            ``(hidden_states, residual)``. The feed-forward output is not
            added to the residual here. The caller passes both to the next
            layer's ``input_layernorm`` (or the model's final ``norm``).
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            # Called with a residual, the norm returns both the normalized
            # states and the updated residual.
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


class DeepseekModel(nn.Module):
    """Deepseek decoder stack: token embedding, decoder layers, final norm.

    Under pipeline parallelism this rank runs only
    ``self.layers[start_layer:end_layer]`` (bounds from ``make_layers``).
    Only the first rank embeds tokens, and only the last applies the final
    norm. Other ranks exchange ``IntermediateTensors``.
    """

    # NOTE(review): loader flag; presumably disables falling back to
    # PyTorch-format checkpoint files during loading — confirm.
    fall_back_to_pt_during_load = False

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: DeepseekDecoderLayer(
                config, cache_config, quant_config=quant_config, prefix=prefix
            ),
            prefix=f"{prefix}.layers")
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed ``input_ids`` with the vocab-parallel embedding."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of the decoder stack.

        Returns the final normalized hidden states on the last pipeline
        rank. On other ranks it returns ``IntermediateTensors`` holding
        ``hidden_states`` and ``residual``.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None, (
                "non-first pipeline ranks require intermediate_tensors")
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        for layer in self.layers[self.start_layer:self.end_layer]:
            hidden_states, residual = layer(positions, hidden_states, residual)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def _should_skip(self, name: str,
                     params_dict: dict[str, nn.Parameter]) -> bool:
        """Return True if checkpoint tensor ``name`` has no target here."""
        # Skip loading extra bias for GPTQ models.
        if name.endswith(".bias") and name not in params_dict:
            return True
        # Skip experts that are not assigned to this worker.
        if (("mlp.experts." in name or "mlp.shared_experts." in name)
                and name not in params_dict):
            return True
        # Skip layers owned by another pipeline-parallel rank.
        return is_pp_missing_parameter(name, self)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this rank's parameters.

        Checkpoint ``q_proj``/``k_proj``/``v_proj`` and
        ``gate_proj``/``up_proj`` tensors are loaded as shards of the fused
        ``qkv_proj`` and ``gate_up_proj`` parameters. Tensors with no local
        target parameter are skipped.

        Returns:
            The set of parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            # Rotary inverse frequencies are never loaded from checkpoints.
            if "rotary_emb.inv_freq" in name:
                continue

            # Map an unfused checkpoint name onto its fused parameter; the
            # first matching entry wins. shard_id stays None for non-stacked
            # weights (0 is a valid shard id, hence the `is None` test).
            shard_id = None
            for (param_name, weight_name,
                 stacked_shard_id) in stacked_params_mapping:
                if weight_name in name:
                    name = name.replace(weight_name, param_name)
                    shard_id = stacked_shard_id
                    break

            if self._should_skip(name, params_dict):
                continue

            param = params_dict[name]
            if shard_id is None:
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            else:
                param.weight_loader(param, loaded_weight, shard_id)
            loaded_params.add(name)
        return loaded_params

440

441
class DeepseekForCausalLM(nn.Module, SupportsPP):
    """Causal-LM wrapper: ``DeepseekModel`` plus an LM head and logits
    processor.

    ``forward`` returns whatever the inner model returns. That is final
    hidden states on the last pipeline rank, and ``IntermediateTensors``
    elsewhere.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        self.config = hf_config
        self.quant_config = vllm_config.quant_config

        self.model = DeepseekModel(vllm_config=vllm_config,
                                   prefix=maybe_prefix(prefix, "model"))
        self.lm_head = ParallelLMHead(hf_config.vocab_size,
                                      hf_config.hidden_size,
                                      quant_config=self.quant_config)
        if hf_config.tie_word_embeddings:
            # The output projection shares its weight tensor with the input
            # embedding.
            self.lm_head.weight = self.model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(hf_config.vocab_size)
        # Pipeline-parallel ranks reuse the inner model's factory.
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate token embedding to the inner model."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder stack; see ``DeepseekModel.forward``."""
        return self.model(input_ids, positions, intermediate_tensors,
                          inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits via ``lm_head``."""
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors using vLLM's ``AutoWeightsLoader``."""
        return AutoWeightsLoader(self).load_weights(weights)