# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Deepseek model."""

from typing import Any, Dict, Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers import PretrainedConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput


class DeepseekMLP(nn.Module):
    """Gated SiLU feed-forward block: gate/up projection -> SiluAndMul ->
    down projection.

    Used as the dense MLP of non-MoE decoder layers and as an individual
    routed or shared expert inside ``DeepseekMoE``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
    ) -> None:
        """Build the fused gate/up projection and the down projection.

        Args:
            hidden_size: Input and output width of the block.
            intermediate_size: Width of each of the gate and up projections
                (the fused projection outputs ``2 * intermediate_size``).
            hidden_act: Activation name; only ``"silu"`` is supported.
            quant_config: Optional quantization config forwarded to both
                linear layers.
            reduce_results: Forwarded to the down projection. MoE experts
                pass ``False`` because the enclosing MoE block performs a
                single all-reduce over the combined output.

        Raises:
            ValueError: If ``hidden_act`` is not ``"silu"``.
        """
        super().__init__()
        # Validate before allocating any (possibly large) weights.
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config)
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           quant_config=quant_config,
                                           reduce_results=reduce_results)
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated MLP to ``x`` and return ``hidden_size``-wide
        activations."""
        # The parallel linear layers return (output, bias); bias is unused.
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class DeepseekMoE(nn.Module):
    """Mixture-of-experts feed-forward block for DeepSeek decoder layers.

    A replicated linear gate scores every routed expert per token, and
    ``fused_moe`` evaluates the top ``config.num_experts_per_tok`` experts
    using the packed ``w1``/``w2`` weight tensors. If
    ``config.n_shared_experts`` is set, a shared dense MLP is applied to
    every token and added to the routed output. Expert MLPs are built with
    ``reduce_results=False``; the combined output is all-reduced once across
    tensor-parallel ranks at the end of ``forward``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Build the routed experts, the router gate and optional shared
        experts.

        Args:
            config: Model config. Reads ``n_routed_experts``,
                ``num_experts_per_tok``, ``hidden_size``,
                ``moe_intermediate_size``, ``hidden_act`` and
                ``n_shared_experts`` here, and ``norm_topk_prob`` in
                ``forward``.
            quant_config: Forwarded to the expert MLPs (the gate is always
                built unquantized).

        Raises:
            ValueError: If the tensor-parallel world size exceeds the
                number of routed experts.
        """
        super().__init__()
        self.config = config
        self.rank = get_tensor_model_parallel_rank()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.n_routed_experts = config.n_routed_experts
        self.top_k = config.num_experts_per_tok
        if self.tp_size > self.n_routed_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {self.n_routed_experts}.")

        # Every rank builds all routed experts. Per-expert reduction is
        # disabled because forward() does a single all-reduce at the end.
        self.experts = nn.ModuleList([
            DeepseekMLP(hidden_size=config.hidden_size,
                        intermediate_size=config.moe_intermediate_size,
                        hidden_act=config.hidden_act,
                        quant_config=quant_config,
                        reduce_results=False)
            for _ in range(self.n_routed_experts)
        ])
        self.pack_params()

        # Router logits: one score per routed expert, kept unquantized.
        self.gate = ReplicatedLinear(config.hidden_size,
                                     self.n_routed_experts,
                                     bias=False,
                                     quant_config=None)

        if config.n_shared_experts is not None:
            # All shared experts are realised as ONE MLP whose intermediate
            # width is n_shared_experts times the per-expert width.
            intermediate_size = (config.moe_intermediate_size *
                                 config.n_shared_experts)
            self.shared_experts = DeepseekMLP(
                hidden_size=config.hidden_size,
                intermediate_size=intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                reduce_results=False,
            )

    @staticmethod
    def _pack(params: List[torch.Tensor]) -> torch.Tensor:
        """Flatten ``params`` into one contiguous buffer, re-point each
        param's ``.data`` at its slice of that buffer, and return the buffer
        viewed as ``(len(params), *params[0].shape)``.

        Assumes every tensor in ``params`` has the same shape (required by
        the final ``view``).
        """
        flat = torch._utils._flatten_dense_tensors(params)
        slices = torch._utils._unflatten_dense_tensors(flat, params)
        for chunk, param in zip(slices, params):
            param.data = chunk
        return flat.view(len(params), *slices[0].shape)

    def pack_params(self):
        """Pack all experts' weights into ``self.w1`` (gate/up) and
        ``self.w2`` (down), the stacked tensors passed to ``fused_moe``.

        Each expert parameter is aliased to its slice of the packed buffer,
        so weights later loaded into an expert parameter are presumably
        visible through ``w1``/``w2`` without a copy.
        """
        # NOTE(review): reads each expert's plain ``.weight``; quantized
        # linear methods may not expose one -- confirm before enabling
        # quantization for experts.
        self.w1 = self._pack(
            [expert.gate_up_proj.weight for expert in self.experts])
        self.w2 = self._pack(
            [expert.down_proj.weight for expert in self.experts])

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Run routed (and shared) experts and all-reduce the result.

        Args:
            hidden_states: 2-D token activations ``(num_tokens, hidden_dim)``.

        Returns:
            Tensor of shape ``(num_tokens, hidden_dim)``.
        """
        num_tokens, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)

        shared_output = None
        if self.config.n_shared_experts is not None:
            # Must run before fused_moe: that call uses inplace=True, which
            # presumably overwrites hidden_states with the routed output.
            shared_output = self.shared_experts(hidden_states)

        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        final_hidden_states = fused_moe(hidden_states,
                                        self.w1,
                                        self.w2,
                                        router_logits,
                                        self.top_k,
                                        renormalize=self.config.norm_topk_prob,
                                        inplace=True)

        if shared_output is not None:
            final_hidden_states = final_hidden_states + shared_output
        # Single TP reduction for routed + shared contributions (the expert
        # MLPs were built with reduce_results=False).
        final_hidden_states = tensor_model_parallel_all_reduce(
            final_hidden_states)

        return final_hidden_states.view(num_tokens, hidden_dim)


class DeepseekAttention(nn.Module):
    """Self-attention with separate query / key-value head counts and rotary
    position embeddings, sharded across tensor-parallel ranks."""

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Build the QKV/output projections, rotary embedding and attention.

        Args:
            hidden_size: Model hidden width; ``head_dim`` is
                ``hidden_size // num_heads``.
            num_heads: Total number of query heads (across all TP ranks).
            num_kv_heads: Total number of key/value heads.
            rope_theta: Rotary embedding base.
            rope_scaling: Optional rotary scaling config passed to
                ``get_rope``.
            max_position_embeddings: Maximum position for the rotary table.
            cache_config: Forwarded to ``Attention``.
            quant_config: Forwarded to the QKV and output projections.
        """
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        # Query heads are always split evenly across ranks.
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        # In the replicated case each rank still holds one KV head.
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        # Per-rank widths of the q and k/v slices of the fused QKV output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
        )

        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Project to q/k/v, apply rotary embeddings, attend, and project
        back to ``hidden_size``.

        Args:
            positions: Token positions used by the rotary embedding.
            hidden_states: Input activations.
            kv_cache: This layer's KV cache, passed through to ``Attention``.
            attn_metadata: Attention metadata, passed through to
                ``Attention``.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        # Rotary position embedding is applied to queries and keys only.
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class DeepseekDecoderLayer(nn.Module):
    """One decoder block: RMSNorm -> self-attention -> RMSNorm ->
    feed-forward, where the feed-forward is a ``DeepseekMoE`` or a dense
    ``DeepseekMLP`` depending on the layer index."""

    def __init__(
        self,
        config: PretrainedConfig,
        layer_idx: int,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Build attention, the feed-forward and both norms.

        Args:
            config: Model config.
            layer_idx: Index of this layer. It selects MoE vs dense MLP.
            cache_config: Forwarded to the attention module.
            quant_config: Forwarded to attention and the feed-forward.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.self_attn = DeepseekAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
        )
        # MoE is used only when routed experts exist, past the first
        # `first_k_dense_replace` layers, and on every `moe_layer_freq`-th
        # layer. All other layers use a dense MLP.
        if (config.n_routed_experts is not None
                and layer_idx >= config.first_k_dense_replace
                and layer_idx % config.moe_layer_freq == 0):
            self.mlp = DeepseekMoE(config=config, quant_config=quant_config)
        else:
            self.mlp = DeepseekMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
            )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the block.

        The residual addition is deferred: the block returns the
        feed-forward output and the running residual separately, and the
        next norm call combines them.

        Args:
            positions: Token positions for rotary embeddings.
            hidden_states: Output of the previous layer (or the embeddings).
            kv_cache: This layer's KV cache.
            attn_metadata: Attention metadata.
            residual: Running residual, or ``None`` for the first layer.

        Returns:
            ``(hidden_states, residual)``.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            # Two-argument RMSNorm presumably fuses the residual add with
            # the norm and returns (normed, new_residual) -- confirm.
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


class DeepseekModel(nn.Module):
    """DeepSeek transformer backbone: token embedding, a stack of
    ``DeepseekDecoderLayer`` and a final RMSNorm."""

    # Class-level loader flag. NOTE(review): presumably read by vLLM's model
    # loader to forbid falling back to PyTorch-format checkpoints -- confirm.
    fall_back_to_pt_during_load = False

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Build the embedding, ``config.num_hidden_layers`` decoder layers
        and the final norm.

        Args:
            config: Model config.
            cache_config: Forwarded to every decoder layer.
            quant_config: Forwarded to every decoder layer.
        """
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList([
            DeepseekDecoderLayer(config,
                                 layer_idx,
                                 cache_config,
                                 quant_config=quant_config)
            for layer_idx in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Embed ``input_ids``, run all decoder layers and apply the final
        norm.

        Args:
            input_ids: Token ids.
            positions: Token positions.
            kv_caches: One KV cache per decoder layer, indexed by layer.
            attn_metadata: Attention metadata shared by all layers.

        Returns:
            Final normalized hidden states.
        """
        hidden_states = self.embed_tokens(input_ids)
        residual = None
        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(positions, hidden_states,
                                            kv_caches[i], attn_metadata,
                                            residual)
        # The final norm also folds in the deferred residual of the last
        # layer.
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class DeepseekForCausalLM(nn.Module):
    """DeepSeek causal LM: ``DeepseekModel`` backbone plus a parallel LM
    head, logits processor, sampler and checkpoint weight loader."""

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Build the backbone, LM head, logits processor and sampler.

        Args:
            config: Model config.
            cache_config: Forwarded to the backbone.
            quant_config: Forwarded to the backbone (the LM head is built
                without it).
        """
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = DeepseekModel(config, cache_config, quant_config)
        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Return the backbone's final hidden states (logits are computed
        separately by ``compute_logits``)."""
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project hidden states to vocabulary logits using the LM head
        weight."""
        logits = self.logits_processor(self.lm_head.weight, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits`` with the model's sampler."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    @staticmethod
    def _is_skippable(name: str, params_dict: Dict[str, Any]) -> bool:
        """Return True for checkpoint tensors that have no local parameter
        and should be ignored: extra GPTQ biases and expert weights not
        assigned to this worker."""
        if name in params_dict:
            return False
        return (name.endswith(".bias") or "mlp.experts." in name
                or "mlp.shared_experts." in name)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load ``(name, tensor)`` checkpoint pairs into this model.

        Separate q/k/v and gate/up checkpoint tensors are routed into the
        fused ``qkv_proj`` / ``gate_up_proj`` parameters with the matching
        shard id. All other tensors are loaded by the parameter's own
        ``weight_loader`` (or ``default_weight_loader``).

        Raises:
            KeyError: If a checkpoint tensor maps to a parameter name that
                does not exist and is not one of the skippable cases.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            # Rotary inverse frequencies are recomputed, not loaded.
            if "rotary_emb.inv_freq" in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                if self._is_skippable(name, params_dict):
                    # `continue` (not `break`) moves on to the next mapping
                    # with the rewritten name. The '.bias' suffix and the
                    # 'mlp.*experts.' marker survive any further rewrite, so
                    # the same check fires again here or in the else clause
                    # below, and the tensor is still skipped.
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # No stacked mapping consumed this tensor: load it directly.
                if self._is_skippable(name, params_dict):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)