deepseek.py 20.1 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Deepseek model."""
25
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
26
27
28
29
30

import torch
from torch import nn
from transformers import PretrainedConfig

31
from vllm.attention import Attention, AttentionMetadata
32
from vllm.config import CacheConfig, VllmConfig
33
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
34
35
                              get_tensor_model_parallel_world_size,
                              tensor_model_parallel_all_reduce)
36
37
38
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.layernorm import RMSNorm
39
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
40
                                               QKVParallelLinear,
41
                                               ReplicatedLinear,
42
                                               RowParallelLinear)
43
from vllm.model_executor.layers.logits_processor import LogitsProcessor
44
from vllm.model_executor.layers.quantization import QuantizationConfig
45
from vllm.model_executor.layers.rotary_embedding import get_rope
Joe Runde's avatar
Joe Runde committed
46
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
47
from vllm.model_executor.layers.vocab_parallel_embedding import (
48
    ParallelLMHead, VocabParallelEmbedding)
49
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
50
from vllm.model_executor.sampling_metadata import SamplingMetadata
51
from vllm.sequence import IntermediateTensors
52

53
from .interfaces import SupportsPP
54
from .utils import (extract_layer_index, is_pp_missing_parameter,
55
56
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
57

58
59
60
61
62
63
64
65

class DeepseekMLP(nn.Module):
    """Gated feed-forward block: fused gate/up projection, SiLU-and-mul, down.

    Used as the dense MLP of non-MoE decoder layers and as the per-expert and
    shared-expert MLP inside ``DeepseekMoE``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
        prefix: str = "",
    ) -> None:
        """Build the projections.

        Args:
            hidden_size: Input and output width of the block.
            intermediate_size: Width of each of the gate and up projections.
            hidden_act: Activation name; only ``"silu"`` is supported.
            quant_config: Optional quantization config forwarded to both
                linear layers.
            reduce_results: Forwarded to ``down_proj``. ``DeepseekMoE``
                passes ``False`` so it can perform a single all-reduce
                itself.
            prefix: Fully-qualified module name; forwarded (with the
                sub-layer name appended) to the linear layers so that
                name-based quantization settings can address them.

        Raises:
            ValueError: If ``hidden_act`` is not ``"silu"``.
        """
        super().__init__()
        # Validate before allocating any (possibly large) projection weights.
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        # gate_proj and up_proj are fused into one matmul; the checkpoint
        # loader maps gate_proj to shard 0 and up_proj to shard 1.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj")
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           quant_config=quant_config,
                                           reduce_results=reduce_results,
                                           prefix=f"{prefix}.down_proj")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply gate/up projection, the SiLU-and-mul activation, then down."""
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class DeepseekMoE(nn.Module):
    """Mixture-of-experts FFN: top-k routed experts plus optional shared ones.

    Routed experts are executed together by ``fused_moe`` over weights that
    ``pack_params`` stacks into ``self.w1`` / ``self.w2``. When
    ``config.n_shared_experts`` is set, a dense shared-expert MLP is applied
    to every token and added to the routed output.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build experts, router and (optionally) shared experts.

        Args:
            config: HF config; reads ``n_routed_experts``,
                ``num_experts_per_tok``, ``hidden_size``,
                ``moe_intermediate_size``, ``hidden_act``,
                ``n_shared_experts`` and (in ``forward``) ``norm_topk_prob``.
            quant_config: Optional quantization config for the expert MLPs.
                The router gate is always built unquantized.
            prefix: Fully-qualified module name, forwarded to sub-modules as
                ``{prefix}.experts.{i}``, ``{prefix}.gate`` and
                ``{prefix}.shared_experts``.

        Raises:
            ValueError: If the tensor-parallel size exceeds the number of
                routed experts.
        """
        super().__init__()
        self.config = config
        self.rank = get_tensor_model_parallel_rank()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.n_routed_experts = config.n_routed_experts
        self.top_k = config.num_experts_per_tok
        if self.tp_size > self.n_routed_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {self.n_routed_experts}.")

        # reduce_results=False: partial results are combined by the single
        # all-reduce at the end of forward().
        self.experts = nn.ModuleList([
            DeepseekMLP(hidden_size=config.hidden_size,
                        intermediate_size=config.moe_intermediate_size,
                        hidden_act=config.hidden_act,
                        quant_config=quant_config,
                        reduce_results=False,
                        prefix=f"{prefix}.experts.{idx}")
            for idx in range(self.n_routed_experts)
        ])
        self.pack_params()

        self.gate = ReplicatedLinear(config.hidden_size,
                                     self.n_routed_experts,
                                     bias=False,
                                     quant_config=None,
                                     prefix=f"{prefix}.gate")

        if config.n_shared_experts is not None:
            # All shared experts are merged into one wider MLP.
            intermediate_size = (config.moe_intermediate_size *
                                 config.n_shared_experts)
            self.shared_experts = DeepseekMLP(
                hidden_size=config.hidden_size,
                intermediate_size=intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                reduce_results=False,
                prefix=f"{prefix}.shared_experts",
            )

    def pack_params(self):
        """Stack all experts' weights into contiguous ``w1`` / ``w2`` tensors.

        Each expert's ``gate_up_proj.weight`` (resp. ``down_proj.weight``) is
        flattened into one shared buffer. Each expert parameter's ``.data``
        is then re-pointed at its slice of that buffer, so later writes into
        the individual expert parameters land in the stacked tensor used by
        ``fused_moe``. The buffers are finally viewed as
        ``(n_experts, *per_expert_weight_shape)``.

        NOTE(review): reads ``.weight`` directly, so this presumably only
        supports unquantized expert weights — confirm before enabling
        quantized MoE.
        """
        w1 = []
        w2 = []
        for expert in self.experts:
            w1.append(expert.gate_up_proj.weight)
            w2.append(expert.down_proj.weight)
        self.w1 = torch._utils._flatten_dense_tensors(w1)
        w1s = torch._utils._unflatten_dense_tensors(self.w1, w1)
        for data, param in zip(w1s, w1):
            param.data = data
        self.w1 = self.w1.view(len(w1), *w1s[0].shape)

        self.w2 = torch._utils._flatten_dense_tensors(w2)
        w2s = torch._utils._unflatten_dense_tensors(self.w2, w2)
        for data, param in zip(w2s, w2):
            param.data = data

        self.w2 = self.w2.view(len(w2), *w2s[0].shape)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route each token to its top-k experts and add shared-expert output.

        Args:
            hidden_states: 2-D tensor ``(num_tokens, hidden_dim)`` (the shape
                is unpacked into exactly two names).

        Returns:
            Tensor of the same shape, all-reduced across tensor-parallel
            ranks.
        """
        num_tokens, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        # Computed before fused_moe, which is called with inplace=True and
        # presumably may reuse hidden_states' storage.
        if self.config.n_shared_experts is not None:
            shared_output = self.shared_experts(hidden_states)
        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        final_hidden_states = fused_moe(hidden_states,
                                        self.w1,
                                        self.w2,
                                        router_logits,
                                        self.top_k,
                                        renormalize=self.config.norm_topk_prob,
                                        inplace=True)

        if self.config.n_shared_experts is not None:
            final_hidden_states = final_hidden_states + shared_output
        # One all-reduce covers both routed and shared experts, which were
        # built with reduce_results=False.
        final_hidden_states = tensor_model_parallel_all_reduce(
            final_hidden_states)

        return final_hidden_states.view(num_tokens, hidden_dim)


class DeepseekAttention(nn.Module):
    """Self-attention with rotary position embeddings, tensor-parallel sharded.

    Query heads are split evenly across tensor-parallel ranks. KV heads are
    split when there are at least as many as ranks, and replicated otherwise.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build QKV/output projections, rotary embedding and attention op.

        Args:
            hidden_size: Model width; ``head_dim = hidden_size // num_heads``.
            num_heads: Total number of query heads across all TP ranks.
            num_kv_heads: Total number of key/value heads across all ranks.
            rope_theta: Rotary embedding base.
            rope_scaling: Optional rotary scaling config passed to
                ``get_rope``.
            max_position_embeddings: Maximum position for the rotary table.
            cache_config: KV-cache configuration for the attention op.
            quant_config: Optional quantization config for the projections
                and attention op.
            prefix: Fully-qualified module name, forwarded as
                ``{prefix}.qkv_proj``, ``{prefix}.o_proj`` and
                ``{prefix}.attn``.
        """
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        # Per-rank widths of the q and k/v slices of the fused QKV output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Project to q/k/v, apply rotary embedding, attend, project out."""
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class DeepseekDecoderLayer(nn.Module):
    """One decoder block: pre-norm self-attention then a dense or MoE FFN."""

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build attention, FFN and the two RMSNorms.

        Args:
            config: HF config for the model.
            cache_config: KV-cache configuration for the attention op.
            quant_config: Optional quantization config.
            prefix: Fully-qualified module name ending in the layer index
                (``make_layers`` is called with ``f"{prefix}.layers"``); the
                index parsed from it decides between MoE and dense FFN.
        """
        super().__init__()
        layer_idx = extract_layer_index(prefix)
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.self_attn = DeepseekAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        # MoE only when routed experts exist, the layer is past the first
        # `first_k_dense_replace` dense layers, and its index is a multiple
        # of `moe_layer_freq`; otherwise a plain dense MLP.
        if (config.n_routed_experts is not None
                and layer_idx >= config.first_k_dense_replace
                and layer_idx % config.moe_layer_freq == 0):
            self.mlp = DeepseekMoE(config=config,
                                   quant_config=quant_config,
                                   prefix=f"{prefix}.mlp")
        else:
            self.mlp = DeepseekMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run attention and FFN sub-blocks.

        The residual stream is carried separately: when ``residual`` is
        ``None`` (first layer), the input itself becomes the residual.
        Otherwise the two-argument RMSNorm call returns both the normalized
        states and the updated residual.

        Returns:
            ``(hidden_states, residual)`` to feed the next layer.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


class DeepseekModel(nn.Module):
    """Token embedding, a pipeline-parallel slice of decoder layers, final norm.

    Only layers ``[start_layer, end_layer)`` are materialized on this
    pipeline rank. Non-final ranks hand ``hidden_states``/``residual`` to the
    next rank as ``IntermediateTensors``.
    """

    fall_back_to_pt_during_load = False

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        hf_config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.padding_idx = hf_config.pad_token_id
        self.vocab_size = hf_config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            hf_config.vocab_size,
            hf_config.hidden_size,
        )

        # make_layers invokes the factory with the keyword argument
        # `prefix=...`, so the parameter name must stay `prefix`.
        def build_layer(prefix: str) -> DeepseekDecoderLayer:
            return DeepseekDecoderLayer(hf_config,
                                        cache_config,
                                        quant_config=quant_config,
                                        prefix=prefix)

        self.start_layer, self.end_layer, self.layers = make_layers(
            hf_config.num_hidden_layers,
            build_layer,
            prefix=f"{prefix}.layers")
        self.norm = RMSNorm(hf_config.hidden_size, eps=hf_config.rms_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], hf_config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's layers.

        ``kv_caches`` is indexed relative to ``start_layer``. Returns the
        normalized hidden states on the last pipeline rank, otherwise the
        intermediate tensors for the next rank.
        """
        pp_group = get_pp_group()
        if pp_group.is_first_rank:
            hidden_states = (inputs_embeds if inputs_embeds is not None else
                             self.get_input_embeddings(input_ids))
            residual = None
        else:
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        local_layer_ids = range(self.start_layer, self.end_layer)
        for cache_slot, layer_id in enumerate(local_layer_ids):
            hidden_states, residual = self.layers[layer_id](
                positions, hidden_states, kv_caches[cache_slot],
                attn_metadata, residual)

        if not pp_group.is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        normed, _ = self.norm(hidden_states, residual)
        return normed


401
class DeepseekForCausalLM(nn.Module, SupportsPP):
    """Deepseek decoder plus LM head, logits processor, sampler and loader."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the model and output head from ``vllm_config``.

        Args:
            vllm_config: Engine config; supplies the HF config and
                quantization config.
            prefix: Parent module name; sub-modules are named ``model`` and
                ``lm_head`` under it via ``maybe_prefix``.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        self.model = DeepseekModel(vllm_config=vllm_config,
                                   prefix=maybe_prefix(prefix, "model"))
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config,
                                      prefix=maybe_prefix(prefix, "lm_head"))
        if self.config.tie_word_embeddings:
            # Output projection shares its weight with the input embedding.
            self.lm_head.weight = self.model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate token-embedding lookup to the inner model."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the inner model; see ``DeepseekModel.forward``."""
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits via ``lm_head``."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def _should_skip_weight(self, name: str,
                            params_dict: Dict[str, nn.Parameter]) -> bool:
        """Return True if checkpoint tensor ``name`` has no target here."""
        # Skip loading extra bias for GPTQ models.
        if name.endswith(".bias") and name not in params_dict:
            return True
        # Skip experts that are not assigned to this worker.
        if (("mlp.experts." in name or "mlp.shared_experts." in name)
                and name not in params_dict):
            return True
        # Skip layers that live on another pipeline-parallel rank.
        return is_pp_missing_parameter(name, self)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors into this worker's parameters.

        Separate q/k/v and gate/up checkpoint tensors are routed into the
        fused ``qkv_proj`` / ``gate_up_proj`` parameters with the matching
        shard id. All other tensors are loaded by name.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            Names (keys of ``named_parameters()``) of parameters that
            received data.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # Keep `name` untouched so later mapping entries are matched
                # against the original checkpoint name (a rewritten
                # "qkv_proj" would otherwise also match "v_proj").
                stacked_name = name.replace(weight_name, param_name)
                if self._should_skip_weight(stacked_name, params_dict):
                    continue
                param = params_dict[stacked_name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                name = stacked_name
                break
            else:
                # Not a stacked shard, or every stacked candidate was skipped.
                if self._should_skip_weight(name, params_dict):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params