deepseek.py 18.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Deepseek model."""
26

27
from collections.abc import Iterable
28
from itertools import islice
29
from typing import Any, Optional, Union
30
31
32
33
34

import torch
from torch import nn
from transformers import PretrainedConfig

35
from vllm.attention import Attention
36
from vllm.config import CacheConfig, VllmConfig
37
38
39
40
41
42
from vllm.distributed import (
    get_pp_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_reduce,
)
43
from vllm.model_executor.layers.activation import SiluAndMul
44
from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
45
from vllm.model_executor.layers.layernorm import RMSNorm
46
47
48
49
50
51
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
52
from vllm.model_executor.layers.logits_processor import LogitsProcessor
53
from vllm.model_executor.layers.quantization import QuantizationConfig
54
from vllm.model_executor.layers.rotary_embedding import get_rope
55
from vllm.model_executor.layers.vocab_parallel_embedding import (
56
57
58
    ParallelLMHead,
    VocabParallelEmbedding,
)
59
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
60
from vllm.sequence import IntermediateTensors
61

62
from .interfaces import SupportsLoRA, SupportsPP
63
64
65
66
67
68
69
70
from .utils import (
    AutoWeightsLoader,
    extract_layer_index,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
71

72
73
74
75
76
77
78

class DeepseekMLP(nn.Module):
    """Gated feed-forward block.

    Applies a merged gate/up projection, the ``SiluAndMul`` activation, and a
    down projection, in that order.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
        prefix: str = "",
    ) -> None:
        """Build the projections and the activation.

        Args:
            hidden_size: Input size of the gate/up projection and output size
                of the down projection.
            intermediate_size: Size of each of the two merged gate/up outputs
                and input size of the down projection.
            hidden_act: Activation name; only ``"silu"`` is accepted.
            quant_config: Forwarded unchanged to both linear layers.
            reduce_results: Forwarded unchanged to the down projection.
            prefix: Accepted for signature parity with the other modules in
                this file; this block does not currently use it.

        Raises:
            ValueError: If ``hidden_act`` is anything other than ``"silu"``.
        """
        super().__init__()
        # Validate up front so an unsupported configuration fails before any
        # projection layer has been constructed.
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
        )
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run gate/up projection, activation, then down projection on ``x``."""
        # Both linear layers return a pair; only the first element is used here.
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class DeepseekMoE(nn.Module):
    """Mixture-of-experts feed-forward block with optional shared experts.

    A replicated linear gate produces per-token router logits, ``fused_topk``
    selects ``top_k`` experts per token, and ``fused_experts`` evaluates them
    using the stacked expert weights prepared by :meth:`pack_params`.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build routed experts, the router gate and optional shared experts.

        Args:
            config: Model config. Reads ``n_routed_experts``,
                ``num_experts_per_tok``, ``hidden_size``,
                ``moe_intermediate_size``, ``hidden_act`` and
                ``n_shared_experts`` here; ``norm_topk_prob`` is read in
                :meth:`forward`.
            quant_config: Forwarded to every expert MLP. The router gate is
                always built with ``quant_config=None``.
            prefix: Accepted for signature parity; not used by this block.

        Raises:
            ValueError: If the tensor-parallel world size exceeds the number
                of routed experts.
        """
        super().__init__()
        self.config = config
        self.rank = get_tensor_model_parallel_rank()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.n_routed_experts = config.n_routed_experts
        self.top_k = config.num_experts_per_tok
        if self.tp_size > self.n_routed_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {self.n_routed_experts}."
            )

        # Experts skip their own reduction (reduce_results=False); forward()
        # performs a single all-reduce on the combined output instead.
        self.experts = nn.ModuleList(
            [
                DeepseekMLP(
                    hidden_size=config.hidden_size,
                    intermediate_size=config.moe_intermediate_size,
                    hidden_act=config.hidden_act,
                    quant_config=quant_config,
                    reduce_results=False,
                )
                for _ in range(self.n_routed_experts)
            ]
        )
        self.pack_params()

        self.gate = ReplicatedLinear(
            config.hidden_size, self.n_routed_experts, bias=False, quant_config=None
        )

        if config.n_shared_experts is not None:
            # One MLP whose width is n_shared_experts times the per-expert width.
            intermediate_size = config.moe_intermediate_size * config.n_shared_experts
            self.shared_experts = DeepseekMLP(
                hidden_size=config.hidden_size,
                intermediate_size=intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                reduce_results=False,
            )

    @staticmethod
    def _pack(params: list[torch.Tensor]) -> torch.Tensor:
        """Flatten ``params`` into one buffer, re-point each at its slice, and
        return the buffer viewed as ``(len(params), *params[0].shape)``."""
        flat = torch._utils._flatten_dense_tensors(params)
        slices = torch._utils._unflatten_dense_tensors(flat, params)
        for chunk, param in zip(slices, params):
            # Each parameter now aliases its slice of the flat buffer.
            param.data = chunk
        return flat.view(len(params), *slices[0].shape)

    def pack_params(self) -> None:
        """Stack per-expert weights into ``self.w1`` and ``self.w2``.

        ``self.w1`` stacks every expert's ``gate_up_proj.weight`` and
        ``self.w2`` every expert's ``down_proj.weight``. Each expert
        parameter's ``.data`` is replaced by a slice of the stacked buffer,
        so the per-expert parameters and the stacked tensors share storage.
        """
        self.w1 = self._pack([expert.gate_up_proj.weight for expert in self.experts])
        self.w2 = self._pack([expert.down_proj.weight for expert in self.experts])

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply routed (and, if configured, shared) experts to each token.

        Args:
            hidden_states: 2-D tensor unpacked as ``(num_tokens, hidden_dim)``.

        Returns:
            Tensor viewed as ``(num_tokens, hidden_dim)`` after the
            tensor-parallel all-reduce.
        """
        num_tokens, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        has_shared_experts = self.config.n_shared_experts is not None
        if has_shared_experts:
            # Evaluated before the routed path, which is called with
            # inplace=True -- presumably overwriting hidden_states; TODO confirm.
            shared_output = self.shared_experts(hidden_states)
        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)

        topk_weights, topk_ids, _ = fused_topk(
            hidden_states,
            router_logits,
            self.top_k,
            renormalize=self.config.norm_topk_prob,
        )

        final_hidden_states = fused_experts(
            hidden_states, self.w1, self.w2, topk_weights, topk_ids, inplace=True
        )

        if has_shared_experts:
            final_hidden_states = final_hidden_states + shared_output
        # One all-reduce covers both the routed and the shared contribution.
        final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

        return final_hidden_states.view(num_tokens, hidden_dim)
197
198
199
200
201
202
203
204
205


class DeepseekAttention(nn.Module):
    """Self-attention block.

    Chains a fused QKV projection, rotary position embedding on the query and
    key, the ``Attention`` layer, and an output projection.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Derive per-rank head counts and build the sub-layers.

        Args:
            hidden_size: Model hidden size; the head dimension is
                ``hidden_size // num_heads``.
            num_heads: Total number of query heads across all tensor-parallel
                ranks; must be divisible by the tensor-parallel world size.
            num_kv_heads: Total number of key/value heads across all ranks.
            rope_theta: Passed to ``get_rope`` as ``base``.
            rope_scaling: Passed to ``get_rope`` unchanged.
            max_position_embeddings: Passed to ``get_rope`` as ``max_position``.
            cache_config: Forwarded to ``Attention``.
            quant_config: Forwarded to both projections and to ``Attention``.
            prefix: Used to name the inner attention layer
                (``f"{prefix}.attn"``).
        """
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()

        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0, (
            f"num_heads ({self.total_num_heads}) must be divisible by the "
            f"tensor parallel size ({tp_size})"
        )
        self.num_heads = self.total_num_heads // tp_size

        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0, (
                f"num_kv_heads ({self.total_num_kv_heads}) must be divisible "
                f"by the tensor parallel size ({tp_size})"
            )
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0, (
                f"tensor parallel size ({tp_size}) must be divisible by "
                f"num_kv_heads ({self.total_num_kv_heads})"
            )
        # Every rank keeps at least one KV head, including when replicating.
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        self.head_dim = hidden_size // self.total_num_heads
        # Per-rank widths of the q and k/v slices of the fused QKV output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
        )

        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project to q/k/v, apply rotary embedding, attend, project out.

        Args:
            positions: Positions passed to the rotary embedding.
            hidden_states: Input to the fused QKV projection.

        Returns:
            The first element returned by the output projection.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # Split the fused projection along the last dim into q, k and v.
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


class DeepseekDecoderLayer(nn.Module):
    """One decoder layer: RMSNorm, self-attention, RMSNorm, then MLP or MoE."""

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build attention, the feed-forward block and both layer norms.

        Args:
            config: Model config. ``rope_theta``, ``rope_scaling`` and
                ``max_position_embeddings`` are read with defaults of
                ``10000``, ``None`` and ``8192`` respectively.
            cache_config: Forwarded to the attention block.
            quant_config: Forwarded to the attention and feed-forward blocks.
            prefix: Module name prefix; the layer index is extracted from it
                and it is extended for the sub-modules.
        """
        super().__init__()
        layer_idx = extract_layer_index(prefix)
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        self.self_attn = DeepseekAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )

        # A layer is MoE only when routed experts are configured, the layer is
        # at or past first_k_dense_replace, and its index is a multiple of
        # moe_layer_freq. Every other layer gets a dense MLP.
        use_moe = (
            config.n_routed_experts is not None
            and layer_idx >= config.first_k_dense_replace
            and layer_idx % config.moe_layer_freq == 0
        )
        if use_moe:
            self.mlp = DeepseekMoE(
                config=config, quant_config=quant_config, prefix=f"{prefix}.mlp"
            )
        else:
            self.mlp = DeepseekMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the layer and return ``(hidden_states, residual)``.

        Args:
            positions: Positions forwarded to the attention block.
            hidden_states: Layer input.
            residual: Residual carried from the previous layer, or ``None``
                when there is no residual yet.

        Returns:
            The feed-forward output and the residual to hand to the next layer.
        """
        # Self Attention
        if residual is None:
            # No residual yet: the raw input becomes the residual.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            # Called with a residual, the norm returns an updated
            # (hidden_states, residual) pair.
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


class DeepseekModel(nn.Module):
    """Decoder stack: token embedding, decoder layers and a final RMSNorm."""

    # NOTE(review): presumably read by the model loader to decide whether a
    # PyTorch-format checkpoint fallback is allowed -- confirm against the loader.
    fall_back_to_pt_during_load = False

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the embedding, this rank's decoder layers and the final norm.

        Args:
            vllm_config: Supplies the HF model config, cache config and
                quantization config.
            prefix: Module name prefix; layers are created under
                ``f"{prefix}.layers"``.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        # make_layers also returns the [start_layer, end_layer) range that
        # forward() iterates over.
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: DeepseekDecoderLayer(
                config, cache_config, quant_config=quant_config, prefix=prefix
            ),
            prefix=f"{prefix}.layers",
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this pipeline rank's slice of the decoder stack.

        Args:
            input_ids: Token ids; embedded on the first pipeline rank when
                ``inputs_embeds`` is not given.
            positions: Positions forwarded to every decoder layer.
            intermediate_tensors: ``hidden_states`` and ``residual`` from the
                previous pipeline rank; required on every rank except the first.
            inputs_embeds: Optional precomputed embeddings used instead of
                embedding ``input_ids`` on the first rank.

        Returns:
            Normalized hidden states on the last pipeline rank; otherwise an
            ``IntermediateTensors`` holding ``hidden_states`` and ``residual``.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None, (
                "intermediate_tensors is required on non-first pipeline ranks"
            )
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states, residual = layer(positions, hidden_states, residual)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Checkpoint names containing ``q_proj``/``k_proj``/``v_proj`` or
        ``gate_proj``/``up_proj`` are redirected to the fused ``qkv_proj`` /
        ``gate_up_proj`` parameters with the matching shard id.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs.

        Returns:
            The parameter names (after any renaming) that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # NOTE: each `continue` in this inner loop moves on to the next
                # mapping entry with the already-renamed `name`; it does not
                # skip to the next checkpoint tensor.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Skip experts that are not assigned to this worker.
                if (
                    "mlp.experts." in name or "mlp.shared_experts." in name
                ) and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Reached whenever the loop above ends without `break`. A
                # `continue` here skips to the next checkpoint tensor, so the
                # name is not recorded as loaded.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Skip experts that are not assigned to this worker.
                if (
                    "mlp.experts." in name or "mlp.shared_experts." in name
                ) and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

456

457
458
459
460
461
class DeepseekForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """Deepseek decoder stack with a language-model head and logits processor."""

    # Fused parameter name -> the checkpoint sub-module names packed into it.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the decoder stack, LM head and logits processor.

        Args:
            vllm_config: Supplies the HF model config and quantization config;
                also passed through to ``DeepseekModel``.
            prefix: Module name prefix applied to ``model`` and ``lm_head``.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        self.model = DeepseekModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        if self.config.tie_word_embeddings:
            # Tied embeddings: the LM head reuses the input embedding weight.
            self.lm_head.weight = self.model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        # Re-export the inner model's factory on the top-level module.
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the inner model's token embeddings for ``input_ids``."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Delegate to the inner ``DeepseekModel`` and return its output.

        Logits are not computed here; see :meth:`compute_logits`.
        """
        hidden_states = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        """Run the logits processor with the LM head on ``hidden_states``."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load ``weights`` via ``AutoWeightsLoader`` and return its result."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)