mixtral.py 21.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

Pierre Stock's avatar
Pierre Stock committed
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Mixtral model."""
王敏's avatar
王敏 committed
26
27
import os
import re
zhuwenwen's avatar
zhuwenwen committed
28
29
from typing import Iterable, Optional, Union

30
from collections.abc import Iterable
31
from itertools import islice
32
from typing import Optional, Union
Pierre Stock's avatar
Pierre Stock committed
33
34
35

import torch
from torch import nn
36
from transformers import MixtralConfig
Pierre Stock's avatar
Pierre Stock committed
37

38
from vllm.attention import Attention
39
from vllm.compilation.decorators import support_torch_compile
40
from vllm.config import CacheConfig, VllmConfig
41
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
42
from vllm.model_executor.layers.fused_moe import FusedMoE
Pierre Stock's avatar
Pierre Stock committed
43
from vllm.model_executor.layers.layernorm import RMSNorm
44
from vllm.model_executor.layers.linear import (QKVParallelLinear,
Philipp Moritz's avatar
Philipp Moritz committed
45
                                               ReplicatedLinear,
Pierre Stock's avatar
Pierre Stock committed
46
                                               RowParallelLinear)
47
from vllm.model_executor.layers.logits_processor import LogitsProcessor
48
from vllm.model_executor.layers.quantization import QuantizationConfig
49
from vllm.model_executor.layers.rotary_embedding import get_rope
Pierre Stock's avatar
Pierre Stock committed
50
from vllm.model_executor.layers.vocab_parallel_embedding import (
51
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
52
53
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, maybe_remap_kv_scale_name)
Pierre Stock's avatar
Pierre Stock committed
54
from vllm.model_executor.sampling_metadata import SamplingMetadata
55
from vllm.sequence import IntermediateTensors
王敏's avatar
王敏 committed
56
57
from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf
Pierre Stock's avatar
Pierre Stock committed
58

59
from .interfaces import SupportsLoRA, SupportsPP
60
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
61
62
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
63

Pierre Stock's avatar
Pierre Stock committed
64

Philipp Moritz's avatar
Philipp Moritz committed
65
66
67
68
69
70
71
72
class MixtralMoE(nn.Module):
    """A tensor-parallel MoE implementation for Mixtral that shards each expert
    across all ranks.

    Each expert's weights are sharded across all ranks and a fused MoE
    kernel is used for the forward pass, and finally we reduce the outputs
    across ranks.
    """

    def __init__(self,
                 num_experts: int,
                 top_k: int,
                 hidden_size: int,
                 intermediate_size: int,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 tp_size: Optional[int] = None,
                 dp_size: Optional[int] = None,
                 prefix: str = ""):
        """Create the router gate and the fused expert layer.

        Args:
            num_experts: Number of experts; also the gate's output width.
            top_k: Forwarded to ``FusedMoE`` as ``top_k``.
            hidden_size: Input width of the gate and of the experts.
            intermediate_size: Forwarded to ``FusedMoE``.
            params_dtype: Parameter dtype for both the gate and the experts.
            quant_config: Quantization config applied to the experts only;
                the gate is always built with ``quant_config=None``.
            tp_size: Forwarded to ``FusedMoE`` unchanged.
            dp_size: Forwarded to ``FusedMoE`` unchanged.
            prefix: Parameter-name prefix of this module; the submodules
                are registered under ``{prefix}.gate`` and
                ``{prefix}.experts``.
        """
        super().__init__()
        self.hidden_size = hidden_size

        # Gate always runs at half / full precision for now.
        self.gate = ReplicatedLinear(hidden_size,
                                     num_experts,
                                     bias=False,
                                     params_dtype=params_dtype,
                                     quant_config=None,
                                     prefix=f"{prefix}.gate")

        self.experts = FusedMoE(num_experts=num_experts,
                                top_k=top_k,
                                hidden_size=hidden_size,
                                intermediate_size=intermediate_size,
                                params_dtype=params_dtype,
                                reduce_results=True,
                                renormalize=True,
                                quant_config=quant_config,
                                tp_size=tp_size,
                                dp_size=dp_size,
                                prefix=f"{prefix}.experts")

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route the tokens through the experts.

        The input is flattened to ``(-1, hidden_size)`` for the gate and the
        fused expert layer, and the result is viewed back to the input's
        original shape.
        """
        # NOTE: hidden_states can have either 1D or 2D shape.
        orig_shape = hidden_states.shape
        hidden_states = hidden_states.view(-1, self.hidden_size)
        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        final_hidden_states = self.experts(hidden_states, router_logits)
        return final_hidden_states.view(orig_shape)
Pierre Stock's avatar
Pierre Stock committed
116
117
118
119


class MixtralAttention(nn.Module):
    """Self-attention block for Mixtral.

    Combines a fused QKV projection, rotary position embedding, the vLLM
    ``Attention`` layer and an output projection.
    """

    def __init__(
        self,
        config: MixtralConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position: int = 4096 * 32,
        rope_theta: float = 10000,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the projections, rotary embedding and attention layer.

        Args:
            config: Model config; only the optional ``head_dim`` attribute
                is read from it here.
            hidden_size: Model hidden size.
            num_heads: Total number of query heads (before TP partitioning).
            num_kv_heads: Total number of KV heads (before TP partitioning).
            max_position: Forwarded to ``get_rope`` as ``max_position``.
            rope_theta: RoPE base; converted with ``int()`` before being
                passed to ``get_rope``.
            cache_config: Forwarded to ``Attention``.
            quant_config: Forwarded to the linear layers and ``Attention``.
            prefix: Parameter-name prefix of this module.
        """
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        # MixtralConfig has an optional head_dim argument
        self.head_dim = getattr(config, "head_dim", None)
        if self.head_dim is None:
            self.head_dim = self.hidden_size // self.total_num_heads
        # Per-rank widths of the q and k/v slices of the fused QKV output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position,
            base=int(self.rope_theta),
            is_neox_style=True,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Apply QKV projection, rotary embedding, attention and o_proj."""
        qkv, _ = self.qkv_proj(hidden_states)
        # Split the fused projection along the last dim into q, k and v.
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


class MixtralDecoderLayer(nn.Module):
    """One Mixtral decoder layer: self-attention followed by a sparse MoE.

    Both sub-blocks are preceded by an ``RMSNorm``; the residual stream is
    threaded through the norms and returned alongside the hidden states.
    """

    def __init__(
        self,
        config: MixtralConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the attention, MoE and normalization submodules.

        Args:
            config: Model config supplying sizes, head counts, expert
                counts and ``rms_norm_eps``.
            cache_config: Forwarded to ``MixtralAttention``.
            quant_config: Forwarded to the attention and MoE submodules.
            prefix: Parameter-name prefix of this layer.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        # Requires transformers > 4.32.0
        rope_theta = getattr(config, "rope_theta", 10000)
        self.self_attn = MixtralAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn")
        self.block_sparse_moe = MixtralMoE(
            num_experts=config.num_local_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            quant_config=quant_config,
            prefix=f"{prefix}.block_sparse_moe")
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the layer and return ``(hidden_states, residual)``.

        Args:
            positions: Forwarded to the attention block.
            hidden_states: Input activations of this layer.
            residual: Residual stream from the previous layer, or ``None``
                when there is none yet; in that case the incoming
                ``hidden_states`` become the residual.

        Returns:
            The MoE output and the residual produced by
            ``post_attention_layernorm``.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.block_sparse_moe(hidden_states)
        return hidden_states, residual
Pierre Stock's avatar
Pierre Stock committed
259

260

261
@support_torch_compile
class MixtralModel(nn.Module):
    """Mixtral transformer stack: token embedding, decoder layers, final norm.

    Supports pipeline parallelism: only the layers in
    ``[start_layer, end_layer)`` are executed on this rank, and
    non-last ranks hand ``hidden_states``/``residual`` to the next rank.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the model from ``vllm_config``.

        Args:
            vllm_config: Supplies the HF model config plus the cache,
                quantization and LoRA configs.
            prefix: Parameter-name prefix of this module.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.quant_config = quant_config
        # Extra embedding rows reserved for LoRA-added tokens (0 w/o LoRA).
        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: MixtralDecoderLayer(
                config, cache_config, quant_config=quant_config, prefix=prefix
            ),
            prefix=f"{prefix}.layers")

        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

        # Name of the active quantization method, or None when unquantized.
        self.quant_method = None
        if quant_config is not None:
            self.quant_method = quant_config.get_name()

        # Opt-in switch (env LLAMA_NN=1) for the weight re-layout performed
        # at the end of load_weights().
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up the token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of the decoder stack.

        On the first pipeline rank the input is ``inputs_embeds`` if given,
        otherwise the embedding of ``input_ids``. On later ranks the input
        comes from ``intermediate_tensors``, which must not be ``None``.

        Returns:
            ``IntermediateTensors`` with ``hidden_states`` and ``residual``
            on non-last ranks; the final-normed hidden states on the last
            rank.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states, residual = layer(positions, hidden_states, residual)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return the checkpoint-to-parameter mapping for expert weights."""
        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        return FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="w1",
            ckpt_down_proj_name="w2",
            ckpt_up_proj_name="w3",
            num_experts=self.config.num_local_experts)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Each ``(name, tensor)`` pair is dispatched, in this order, to:
        kv-cache quantization scales, the stacked ``qkv_proj`` shards, the
        expert weights, or the default per-parameter loader. Weights whose
        parameter is not present on this pipeline rank, and extra bias
        tensors without a matching parameter, are skipped.

        When ``LLAMA_NN=1`` and the model is unquantized, selected 2-D
        weights are additionally re-laid-out in place after loading.

        Returns:
            The set of parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        expert_params_mapping = self.get_expert_mapping()
        for name, loaded_weight in weights:
            if (self.quant_config is not None and
                (scale_name := self.quant_config.get_cache_scale(name))):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
                                 loaded_weight[0])
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if ((name.endswith(".bias") or name.endswith("_bias"))
                        and name not in params_dict):
                    continue
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
                if name.endswith("scale"):
                    # Remapping the name of FP8 kv-scale.
                    name = maybe_remap_kv_scale_name(name, params_dict)
                    # NOTE(review): if this is None, the next membership
                    # test on `name` would raise TypeError — confirm this
                    # path is unreachable for real checkpoints.
                    if name is None:
                        continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    # Skip layers on other devices.
                    if is_pp_missing_parameter(name, self):
                        continue
                    if ((name.endswith(".bias") or name.endswith("_bias"))
                            and name not in params_dict):
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param,
                                  loaded_weight,
                                  name,
                                  shard_id=shard_id,
                                  expert_id=expert_id)
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if ((name.endswith(".bias") or name.endswith("_bias"))
                            and name not in params_dict):
                        continue
                    # Skip layers on other devices.
                    if is_pp_missing_parameter(name, self):
                        continue
                    # Remapping the name of FP8 kv-scale.
                    name = maybe_remap_kv_scale_name(name, params_dict)
                    if name is None:
                        continue

                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)

        if self.use_llama_nn and self.quant_method is None:
            # Parameter-name fragments selecting the weights to re-lay-out.
            # Matched as literal substrings (a regex built from these would
            # treat "." as a wildcard).
            # NOTE(review): lm_head is owned by MixtralForCausalLM in this
            # file, so the last entry looks unreachable here — confirm.
            lay_key_words = (
                "block_sparse_moe.gate.weight",
                "self_attn.qkv_proj.weight",
                "self_attn.o_proj.weight",
                "lm_head.weight",
            )

            if loaded_params:
                # Loop-invariant; set once instead of on every iteration.
                # NOTE(review): the consumer of LM_NN is not visible in
                # this file — confirm its meaning.
                os.environ['LM_NN'] = '0'

            for layername in loaded_params:
                weight = params_dict[layername]
                if not any(key in layername for key in lay_key_words):
                    continue

                _weight = torch.zeros_like(weight.data)
                ori_shape = _weight.shape

                # Presumably writes a transposed/re-tiled copy of the weight
                # into _weight for the LLAMA_NN GEMM path — TODO confirm
                # against the custom op's definition.
                ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0],
                                   _weight.shape[1])
                weight.data.copy_(_weight)

                # Expose the re-laid-out buffer with dim 0 and dim 1 swapped.
                weight.data = weight.data.reshape(ori_shape[1], -1)

        return loaded_params
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524


class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """Mixtral decoder stack with a language-model head.

    Wraps ``MixtralModel`` and adds the ``lm_head`` projection and the
    logits processor.
    """
    fall_back_to_pt_during_load = False

    # Fused module name -> names of the checkpoint modules packed into it.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
    }

    # LoRA specific attributes
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the backbone, the LM head and the logits processor.

        Args:
            vllm_config: Supplies the HF model config plus the quantization
                and LoRA configs.
            prefix: Parameter-name prefix; the backbone is registered under
                ``maybe_prefix(prefix, "model")``.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        self.config = config
        self.lora_config = lora_config
        self.quant_config = quant_config

        self.model = MixtralModel(vllm_config=vllm_config,
                                  prefix=maybe_prefix(prefix, "model"))
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE
            # We need bigger padding if using lora for kernel
            # compatibility
            if not lora_config else lora_config.lora_vocab_padding_size,
            quant_config=quant_config,
        )
        if self.config.tie_word_embeddings:
            # Share one weight between the input embedding and the LM head.
            self.lm_head.weight = self.model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate the embedding lookup to the backbone."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the backbone and return its output unchanged."""
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Compute logits from ``hidden_states`` via the logits processor."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights through ``AutoWeightsLoader``.

        Returns:
            The value returned by ``AutoWeightsLoader.load_weights``.
        """
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return the backbone's expert-weight mapping."""
        return self.model.get_expert_mapping()