# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Deepseek model."""
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union

import torch
from torch import nn
from transformers import PretrainedConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsPP
from .utils import (extract_layer_index, is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)


class DeepseekMLP(nn.Module):
    """Gated SiLU feed-forward block: fused gate/up projection, SiLU-and-mul,
    then a down projection.

    The gate and up projections share one column-parallel linear layer
    (output sizes ``[intermediate_size] * 2``); the down projection is
    row-parallel.

    Args:
        hidden_size: Input and output feature size.
        intermediate_size: Output size of each of the gate and up projections.
        hidden_act: Activation name; only ``"silu"`` is supported.
        quant_config: Optional quantization config forwarded to both linear
            layers.
        reduce_results: Forwarded to ``RowParallelLinear``. ``DeepseekMoE``
            passes ``False`` and performs the tensor-parallel all-reduce
            itself after combining expert outputs.
        prefix: Module name prefix; the sub-layers are registered with the
            linear layers as ``{prefix}.gate_up_proj`` and
            ``{prefix}.down_proj`` so prefix-based quantization lookups can
            match them.

    Raises:
        ValueError: If ``hidden_act`` is not ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj")
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           quant_config=quant_config,
                                           reduce_results=reduce_results,
                                           prefix=f"{prefix}.down_proj")
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply gate/up projection, SiLU-and-mul, and down projection."""
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class DeepseekMoE(nn.Module):
    """Mixture-of-experts feed-forward layer with optional shared experts.

    Routed experts are ordinary ``DeepseekMLP`` modules whose weights are
    repacked by ``pack_params`` into two stacked tensors (``w1`` for the
    fused gate/up projections, ``w2`` for the down projections) that are
    passed straight to ``fused_moe``. When ``config.n_shared_experts`` is set,
    an always-on shared MLP is added to the routed output. The
    tensor-parallel all-reduce is performed once, on the combined result.

    Args:
        config: HF config; reads ``n_routed_experts``, ``num_experts_per_tok``,
            ``hidden_size``, ``moe_intermediate_size``, ``hidden_act``,
            ``n_shared_experts`` and ``norm_topk_prob``.
        quant_config: Optional quantization config for the expert MLPs (the
            router gate is always built unquantized).
        prefix: Module name prefix for the expert sub-modules.

    Raises:
        ValueError: If the tensor-parallel size exceeds the number of routed
            experts.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.rank = get_tensor_model_parallel_rank()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.n_routed_experts = config.n_routed_experts
        self.top_k = config.num_experts_per_tok
        if self.tp_size > self.n_routed_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {self.n_routed_experts}.")

        # Every rank builds every routed expert (each a tensor-parallel
        # shard); reduce_results=False defers the all-reduce to forward().
        self.experts = nn.ModuleList([
            DeepseekMLP(hidden_size=config.hidden_size,
                        intermediate_size=config.moe_intermediate_size,
                        hidden_act=config.hidden_act,
                        quant_config=quant_config,
                        reduce_results=False,
                        prefix=f"{prefix}.experts.{idx}")
            for idx in range(self.n_routed_experts)
        ])
        self.pack_params()

        # The router is deliberately left unquantized (quant_config=None).
        self.gate = ReplicatedLinear(config.hidden_size,
                                     self.n_routed_experts,
                                     bias=False,
                                     quant_config=None)

        if config.n_shared_experts is not None:
            # All shared experts are merged into one wider MLP.
            intermediate_size = (config.moe_intermediate_size *
                                 config.n_shared_experts)
            self.shared_experts = DeepseekMLP(
                hidden_size=config.hidden_size,
                intermediate_size=intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                reduce_results=False,
                prefix=f"{prefix}.shared_experts",
            )

    def pack_params(self):
        """Repack the routed experts' weights into two contiguous tensors.

        After this call each expert's ``gate_up_proj.weight`` and
        ``down_proj.weight`` is a view into ``self.w1`` / ``self.w2``, so the
        weight loader writing to the per-expert parameters fills the stacked
        tensors that ``forward`` hands to ``fused_moe``. ``self.w1`` is viewed
        as ``(n_experts, *gate_up_weight.shape)`` and ``self.w2`` as
        ``(n_experts, *down_weight.shape)``.
        """
        w1 = []
        w2 = []
        for expert in self.experts:
            w1.append(expert.gate_up_proj.weight)
            w2.append(expert.down_proj.weight)
        # Copy all gate/up weights into one flat buffer, then re-point each
        # parameter at its slice of that buffer.
        self.w1 = torch._utils._flatten_dense_tensors(w1)
        w1s = torch._utils._unflatten_dense_tensors(self.w1, w1)
        for data, param in zip(w1s, w1):
            param.data = data
        self.w1 = self.w1.view(len(w1), *w1s[0].shape)

        # Same treatment for the down projections.
        self.w2 = torch._utils._flatten_dense_tensors(w2)
        w2s = torch._utils._unflatten_dense_tensors(self.w2, w2)
        for data, param in zip(w2s, w2):
            param.data = data

        self.w2 = self.w2.view(len(w2), *w2s[0].shape)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route each token to its top-k experts and add shared experts.

        Args:
            hidden_states: 2-D ``(num_tokens, hidden_dim)`` activations (the
                shape is unpacked into exactly two names). Its contents are
                overwritten, because ``fused_moe`` is called with
                ``inplace=True``.

        Returns:
            Tensor of shape ``(num_tokens, hidden_dim)``, all-reduced across
            tensor-parallel ranks.
        """
        num_tokens, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        # Shared experts must read hidden_states BEFORE fused_moe runs:
        # with inplace=True, fused_moe writes its result into hidden_states.
        shared_output = None
        if self.config.n_shared_experts is not None:
            shared_output = self.shared_experts(hidden_states)
        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        final_hidden_states = fused_moe(hidden_states,
                                        self.w1,
                                        self.w2,
                                        router_logits,
                                        self.top_k,
                                        renormalize=self.config.norm_topk_prob,
                                        inplace=True)

        if shared_output is not None:
            final_hidden_states = final_hidden_states + shared_output
        # Routed and shared experts were built with reduce_results=False, so
        # their partial tensor-parallel results are reduced once here.
        final_hidden_states = tensor_model_parallel_all_reduce(
            final_hidden_states)

        return final_hidden_states.view(num_tokens, hidden_dim)


class DeepseekAttention(nn.Module):
    """Self-attention with rotary position embeddings and tensor parallelism.

    Query heads are split evenly across tensor-parallel ranks. KV heads are
    partitioned when there are at least as many of them as ranks; otherwise
    they are replicated so that each rank holds one.

    Args:
        hidden_size: Model hidden size; ``head_dim`` is
            ``hidden_size // num_heads``.
        num_heads: Total number of query heads (must divide by the TP size).
        num_kv_heads: Total number of key/value heads.
        rope_theta: Rotary embedding base.
        rope_scaling: Optional rotary scaling config forwarded to ``get_rope``.
        max_position_embeddings: Maximum position forwarded to ``get_rope``.
        cache_config: KV-cache config forwarded to ``Attention``.
        quant_config: Optional quantization config for the projections and
            attention.
        prefix: Module name prefix; sub-layers get ``{prefix}.qkv_proj``,
            ``{prefix}.o_proj`` and ``{prefix}.attn``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        # Per-rank widths of the q and k/v slices of the fused qkv output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Project to q/k/v, apply rotary embeddings, attend, project out."""
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class DeepseekDecoderLayer(nn.Module):
    """One transformer block: pre-norm self-attention followed by an MLP.

    The MLP is a ``DeepseekMoE`` when the config defines routed experts, the
    layer index is at least ``first_k_dense_replace``, and the index is a
    multiple of ``moe_layer_freq``. Otherwise it is a dense ``DeepseekMLP``.

    Args:
        config: HF model config.
        cache_config: KV-cache config forwarded to the attention layer.
        quant_config: Optional quantization config.
        prefix: Module name prefix; the layer index is extracted from it.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        layer_idx = extract_layer_index(prefix)
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.self_attn = DeepseekAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        if (config.n_routed_experts is not None
                and layer_idx >= config.first_k_dense_replace
                and layer_idx % config.moe_layer_freq == 0):
            self.mlp = DeepseekMoE(config=config,
                                   quant_config=quant_config,
                                   prefix=f"{prefix}.mlp")
        else:
            self.mlp = DeepseekMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run attention and MLP; return ``(hidden_states, residual)``.

        The residual stream is carried separately and folded in by the
        two-argument ``RMSNorm`` calls. ``residual`` is ``None`` only for the
        first layer, whose input becomes the initial residual.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


class DeepseekModel(nn.Module):
    """Deepseek decoder stack: token embedding, decoder layers, final norm.

    Under pipeline parallelism each rank owns only layers
    ``[start_layer, end_layer)``. Non-final ranks return ``hidden_states``
    and ``residual`` as ``IntermediateTensors`` for the next rank instead of
    applying the final norm.
    """

    # NOTE(review): read by the model loader; presumably disables falling
    # back to PyTorch *.pt checkpoint files for this model — confirm.
    fall_back_to_pt_during_load = False

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: DeepseekDecoderLayer(
                config, cache_config, quant_config=quant_config, prefix=prefix
            ),
            prefix=f"{prefix}.layers")
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of the decoder stack.

        Returns the final-normed hidden states on the last pipeline rank,
        otherwise an ``IntermediateTensors`` with ``hidden_states`` and
        ``residual`` for the next rank.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            # Later pipeline ranks resume from the previous rank's output.
            assert intermediate_tensors is not None, (
                "intermediate_tensors is required on non-first PP ranks")
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            # kv_caches covers only this rank's layers, indexed from
            # start_layer.
            hidden_states, residual = layer(positions, hidden_states,
                                            kv_caches[i - self.start_layer],
                                            attn_metadata, residual)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class DeepseekForCausalLM(nn.Module, SupportsPP):
    """Deepseek decoder plus LM head, logits processing and sampling."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        self.model = DeepseekModel(vllm_config=vllm_config,
                                   prefix=maybe_prefix(prefix, "model"))
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config)
        if self.config.tie_word_embeddings:
            # lm_head and embed_tokens share one Parameter object.
            self.lm_head.weight = self.model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder stack; see ``DeepseekModel.forward``."""
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits via ``lm_head``."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def _skip_weight(self, name: str,
                     params_dict: Dict[str, nn.Parameter]) -> bool:
        """Return True if checkpoint tensor ``name`` has no target here."""
        # Skip loading extra bias for GPTQ models.
        if name.endswith(".bias") and name not in params_dict:
            return True
        # Skip experts that are not assigned to this worker.
        if (("mlp.experts." in name or "mlp.shared_experts." in name)
                and name not in params_dict):
            return True
        # Skip layers owned by another pipeline-parallel rank.
        return is_pp_missing_parameter(name, self)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors into this model's parameters.

        Separate q/k/v and gate/up checkpoint tensors are loaded as shards of
        the fused ``qkv_proj`` / ``gate_up_proj`` parameters. Tensors with no
        destination on this worker are skipped (see ``_skip_weight``), as are
        ``rotary_emb.inv_freq`` buffers and, with tied embeddings,
        ``lm_head.weight``.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The set of parameter names (after fused-name mapping) that were
            loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            # With tied embeddings lm_head shares embed_tokens' Parameter, and
            # named_parameters() de-duplicates it, so "lm_head.weight" has no
            # entry in params_dict.
            if self.config.tie_word_embeddings and "lm_head.weight" in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Stop at the first matching shard even when skipping: the
                # renamed string can contain later shard names ("up_proj" in
                # "gate_up_proj", "v_proj" in "qkv_proj") and must not be
                # rewritten again.
                if self._skip_weight(name, params_dict):
                    break
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                loaded_params.add(name)
                break
            else:
                if self._skip_weight(name, params_dict):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
                loaded_params.add(name)
        return loaded_params