deepseek.py 19 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Deepseek model."""
24
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
25
26
27
28
29

import torch
from torch import nn
from transformers import PretrainedConfig

30
from vllm.attention import Attention, AttentionMetadata
31
from vllm.config import CacheConfig
32
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
33
34
                              get_tensor_model_parallel_world_size,
                              tensor_model_parallel_all_reduce)
35
36
37
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.layernorm import RMSNorm
38
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
39
                                               QKVParallelLinear,
40
                                               ReplicatedLinear,
41
                                               RowParallelLinear)
42
from vllm.model_executor.layers.logits_processor import LogitsProcessor
43
from vllm.model_executor.layers.quantization import QuantizationConfig
44
from vllm.model_executor.layers.rotary_embedding import get_rope
45
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
46
from vllm.model_executor.layers.vocab_parallel_embedding import (
47
    ParallelLMHead, VocabParallelEmbedding)
48
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
49
from vllm.model_executor.sampling_metadata import SamplingMetadata
50
from vllm.sequence import IntermediateTensors
51

52
53
54
55
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers)

56
57
58
59
60
61
62
63

class DeepseekMLP(nn.Module):
    """Gated feed-forward block: fused gate/up projection, SiLU-and-mul,
    then a down projection.

    Args:
        hidden_size: Input and output feature size.
        intermediate_size: Output size of each of the gate and up
            projections (the fused projection produces ``2 *`` this).
        hidden_act: Activation name. Only ``"silu"`` is supported.
        quant_config: Optional quantization config forwarded to both linear
            layers.
        reduce_results: Forwarded to the down projection
            (``RowParallelLinear``). ``DeepseekMoE`` passes ``False`` and
            performs one all-reduce over its combined output instead.

    Raises:
        ValueError: If ``hidden_act`` is not ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
    ) -> None:
        super().__init__()
        # Validate before allocating any weights so a bad config fails fast.
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config)
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           quant_config=quant_config,
                                           reduce_results=reduce_results)
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply gate/up projection, SiLU-and-mul, and down projection."""
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class DeepseekMoE(nn.Module):
    """Mixture-of-experts feed-forward block.

    All ``config.n_routed_experts`` experts are built as ``DeepseekMLP``
    modules with ``reduce_results=False``. Their weights are then packed
    into stacked tensors (``self.w1`` / ``self.w2``) that are passed to
    ``fused_moe``. When ``config.n_shared_experts`` is set, a shared-expert
    MLP runs on every token. Its output is added to the routed output before
    a single tensor-parallel all-reduce.

    Raises:
        ValueError: If the tensor-parallel size exceeds the number of
            routed experts.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.rank = get_tensor_model_parallel_rank()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.n_routed_experts = config.n_routed_experts
        self.top_k = config.num_experts_per_tok
        if self.tp_size > self.n_routed_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {self.n_routed_experts}.")

        self.experts = nn.ModuleList([
            DeepseekMLP(hidden_size=config.hidden_size,
                        intermediate_size=config.moe_intermediate_size,
                        hidden_act=config.hidden_act,
                        quant_config=quant_config,
                        reduce_results=False)
            for _ in range(self.n_routed_experts)
        ])
        self.pack_params()

        # Router logits come from a replicated, unquantized projection.
        self.gate = ReplicatedLinear(config.hidden_size,
                                     self.n_routed_experts,
                                     bias=False,
                                     quant_config=None)

        if config.n_shared_experts is not None:
            # All shared experts are modeled as one DeepseekMLP whose
            # intermediate size is scaled by the number of shared experts.
            intermediate_size = (config.moe_intermediate_size *
                                 config.n_shared_experts)
            self.shared_experts = DeepseekMLP(
                hidden_size=config.hidden_size,
                intermediate_size=intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                reduce_results=False,
            )

    def pack_params(self):
        """Pack the experts' weights into contiguous stacked tensors.

        Every expert's ``gate_up_proj.weight`` is flattened into one buffer.
        Each parameter's ``.data`` is re-pointed at its slice of that buffer.
        ``self.w1`` is the buffer viewed as
        ``(n_routed_experts, *per_expert_shape)``. ``down_proj.weight`` is
        packed into ``self.w2`` the same way.

        Because the parameters alias the buffers, weights later loaded into
        the per-expert parameters become visible through ``self.w1`` and
        ``self.w2``. This assumes the weight loaders copy in place rather
        than rebinding ``.data`` — TODO confirm.
        """
        self.w1 = self._pack_into_buffer(
            [expert.gate_up_proj.weight for expert in self.experts])
        self.w2 = self._pack_into_buffer(
            [expert.down_proj.weight for expert in self.experts])

    @staticmethod
    def _pack_into_buffer(weights: List[torch.Tensor]) -> torch.Tensor:
        """Alias ``weights`` into one flat buffer and return it stacked.

        Args:
            weights: Same-shaped tensors, one per expert.

        Returns:
            The shared buffer viewed as ``(len(weights), *weights[0].shape)``.
        """
        flat = torch._utils._flatten_dense_tensors(weights)
        views = torch._utils._unflatten_dense_tensors(flat, weights)
        for view, param in zip(views, weights):
            param.data = view
        return flat.view(len(weights), *views[0].shape)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens through the experts and return the combined output.

        Args:
            hidden_states: Token features. The input is unpacked as
                ``(num_tokens, hidden_dim)``, so it must be 2-D.

        Returns:
            Tensor of shape ``(num_tokens, hidden_dim)``, all-reduced across
            the tensor-parallel group.
        """
        num_tokens, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        if self.config.n_shared_experts is not None:
            # Compute before fused_moe: with inplace=True the fused kernel
            # may write its result into hidden_states' storage.
            shared_output = self.shared_experts(hidden_states)
        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        final_hidden_states = fused_moe(hidden_states,
                                        self.w1,
                                        self.w2,
                                        router_logits,
                                        self.top_k,
                                        renormalize=self.config.norm_topk_prob,
                                        inplace=True)

        if self.config.n_shared_experts is not None:
            final_hidden_states = final_hidden_states + shared_output
        # Experts and shared experts were built with reduce_results=False,
        # so the tensor-parallel reduction happens once, here.
        final_hidden_states = tensor_model_parallel_all_reduce(
            final_hidden_states)

        return final_hidden_states.view(num_tokens, hidden_dim)


class DeepseekAttention(nn.Module):
    """Self-attention with rotary position embeddings, sharded across
    tensor-parallel ranks.

    Query heads are divided evenly across ranks. KV heads are divided when
    there are at least as many KV heads as ranks. Otherwise they are
    replicated so that each rank holds exactly one.

    Args:
        hidden_size: Model hidden size. ``head_dim`` is
            ``hidden_size // num_heads``.
        num_heads: Total number of query heads (across all ranks).
        num_kv_heads: Total number of key/value heads (across all ranks).
        rope_theta: Rotary embedding base.
        rope_scaling: Optional rotary scaling config passed to ``get_rope``.
        max_position_embeddings: Maximum position passed to ``get_rope``.
        cache_config: Optional KV-cache config forwarded to ``Attention``.
        quant_config: Optional quantization config forwarded to the
            projections and ``Attention``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than or equal to the TP size, so
            # we partition the KV heads across the tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        # Per-rank widths of the q and k/v slices of the fused QKV output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
        )

        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Project to q/k/v, apply rotary embeddings, attend, and project out.

        Args:
            positions: Positions passed to the rotary embedding.
            hidden_states: Input features for the QKV projection.
            kv_cache: KV cache forwarded to ``Attention``.
            attn_metadata: Attention metadata forwarded to ``Attention``.

        Returns:
            Output of ``o_proj`` (feature size ``hidden_size``).
        """
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class DeepseekDecoderLayer(nn.Module):
    """One decoder layer: pre-norm self-attention, then a pre-norm
    feed-forward block.

    The feed-forward block is a ``DeepseekMoE`` when routed experts are
    configured, ``layer_idx >= config.first_k_dense_replace``, and
    ``layer_idx`` is a multiple of ``config.moe_layer_freq``. Otherwise it
    is a dense ``DeepseekMLP``.

    Args:
        config: Model config.
        layer_idx: Index of this layer in the full model.
        cache_config: Optional KV-cache config forwarded to attention.
        quant_config: Optional quantization config forwarded to sublayers.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        layer_idx: int,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.self_attn = DeepseekAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
        )
        if (config.n_routed_experts is not None
                and layer_idx >= config.first_k_dense_replace
                and layer_idx % config.moe_layer_freq == 0):
            self.mlp = DeepseekMoE(config=config, quant_config=quant_config)
        else:
            self.mlp = DeepseekMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
            )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run attention and the feed-forward block.

        Args:
            positions: Positions forwarded to attention.
            hidden_states: Layer input.
            kv_cache: This layer's KV cache.
            attn_metadata: Attention metadata.
            residual: Residual stream from the previous layer, or ``None``.
                ``DeepseekModel`` passes ``None`` to the first layer on the
                first pipeline rank.

        Returns:
            ``(hidden_states, residual)``. The residual is kept separate so
            the next layer's input norm can combine it.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            # Two-argument RMSNorm returns (normed, new_residual); presumably
            # a fused residual-add + norm — confirm in RMSNorm.
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


class DeepseekModel(nn.Module):
    """DeepSeek decoder stack: token embedding, the decoder layers owned by
    this pipeline stage, and the final RMSNorm.

    Args:
        config: Model config.
        cache_config: Optional KV-cache config forwarded to each layer.
        quant_config: Optional quantization config forwarded to each layer.
        prefix: Module-name prefix. Layers are created under
            ``f"{prefix}.layers"``, and each layer's index is parsed from the
            last dotted component of its prefix.
    """

    # NOTE(review): class flag presumably read by vLLM's weight loader
    # (disables falling back to .pt checkpoints) — confirm.
    fall_back_to_pt_during_load = False

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        # Only layers in [start_layer, end_layer) belong to this pipeline
        # stage.
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: DeepseekDecoderLayer(config,
                                                int(prefix.split(".")[-1]),
                                                cache_config,
                                                quant_config=quant_config),
            prefix=f"{prefix}.layers")
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this pipeline stage.

        The first rank embeds ``input_ids``. Later ranks resume from the
        ``hidden_states``/``residual`` in ``intermediate_tensors``.
        ``kv_caches`` is indexed relative to ``start_layer``, so it holds one
        entry per local layer.

        Returns:
            ``IntermediateTensors`` on every rank except the last. On the
            last rank, the final-normed hidden states.
        """
        if get_pp_group().is_first_rank:
            hidden_states = self.embed_tokens(input_ids)
            residual = None
        else:
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states, residual = layer(positions, hidden_states,
                                            kv_caches[i - self.start_layer],
                                            attn_metadata, residual)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class DeepseekForCausalLM(nn.Module, SupportsPP):
    """DeepSeek causal language model: ``DeepseekModel`` plus LM head,
    logits processor, and sampler.

    Args:
        config: Model config. When ``config.tie_word_embeddings`` is set,
            ``lm_head.weight`` is the same Parameter as
            ``model.embed_tokens.weight``.
        cache_config: Optional KV-cache config.
        quant_config: Optional quantization config.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = DeepseekModel(config, cache_config, quant_config)
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config)
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder stack. See ``DeepseekModel.forward``."""
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, intermediate_tensors)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Compute logits from final hidden states via ``lm_head``."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors into this model's parameters.

        Checkpoint tensors are mapped onto fused parameters as follows:

        - ``q_proj``/``k_proj``/``v_proj`` are loaded as shards of the fused
          ``qkv_proj`` parameter.
        - ``gate_proj``/``up_proj`` are loaded as shards of the fused
          ``gate_up_proj`` parameter.

        Tensors with no matching local parameter are skipped in these cases:

        - extra GPTQ biases;
        - expert weights not present here;
        - weights of layers owned by another pipeline stage;
        - a tied ``lm_head.weight``.

        Args:
            weights: Iterable of ``(name, tensor)`` checkpoint entries.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            # With tied embeddings, lm_head.weight is the same Parameter as
            # model.embed_tokens.weight. named_parameters() lists it only
            # once, under the embedding's name, so a checkpoint
            # "lm_head.weight" entry would raise KeyError below. The shared
            # tensor is populated from the embedding's entry instead.
            if (self.config.tie_word_embeddings
                    and "lm_head.weight" in name):
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Skip experts that are not assigned to this worker.
                if (("mlp.experts." in name or "mlp.shared_experts." in name)
                        and name not in params_dict):
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Skip experts that are not assigned to this worker.
                if (("mlp.experts." in name or "mlp.shared_experts." in name)
                        and name not in params_dict):
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)