mixtral.py 21.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

Pierre Stock's avatar
Pierre Stock committed
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Mixtral model."""
王敏's avatar
王敏 committed
26
27
import os
import re
zhuwenwen's avatar
zhuwenwen committed
28
29
from typing import Iterable, Optional, Union

30
31
from collections.abc import Iterable
from typing import Optional, Union
Pierre Stock's avatar
Pierre Stock committed
32
33
34

import torch
from torch import nn
35
from transformers import MixtralConfig
Pierre Stock's avatar
Pierre Stock committed
36

37
from vllm.attention import Attention
38
from vllm.compilation.decorators import support_torch_compile
39
from vllm.config import CacheConfig, VllmConfig
40
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
41
from vllm.model_executor.layers.fused_moe import FusedMoE
Pierre Stock's avatar
Pierre Stock committed
42
from vllm.model_executor.layers.layernorm import RMSNorm
43
from vllm.model_executor.layers.linear import (QKVParallelLinear,
Philipp Moritz's avatar
Philipp Moritz committed
44
                                               ReplicatedLinear,
Pierre Stock's avatar
Pierre Stock committed
45
                                               RowParallelLinear)
46
from vllm.model_executor.layers.logits_processor import LogitsProcessor
47
from vllm.model_executor.layers.quantization import QuantizationConfig
48
from vllm.model_executor.layers.rotary_embedding import get_rope
Pierre Stock's avatar
Pierre Stock committed
49
from vllm.model_executor.layers.vocab_parallel_embedding import (
50
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
51
52
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, maybe_remap_kv_scale_name)
Pierre Stock's avatar
Pierre Stock committed
53
from vllm.model_executor.sampling_metadata import SamplingMetadata
54
from vllm.sequence import IntermediateTensors
王敏's avatar
王敏 committed
55
56
from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf
Pierre Stock's avatar
Pierre Stock committed
57

58
from .interfaces import SupportsLoRA, SupportsPP
59
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
60
61
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
62

Pierre Stock's avatar
Pierre Stock committed
63

Philipp Moritz's avatar
Philipp Moritz committed
64
65
66
67
68
69
70
71
class MixtralMoE(nn.Module):
    """Tensor-parallel mixture-of-experts block used by Mixtral.

    Every expert has its weights sharded across all ranks. The forward
    pass routes tokens with a replicated gate, runs a fused MoE kernel,
    and reduces the expert outputs across ranks.
    """

    def __init__(
        self,
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        tp_size: Optional[int] = None,
        dp_size: Optional[int] = None,
        prefix: str = "",
    ):
        """Build the router gate and the fused expert layer.

        Args:
            num_experts: Number of experts; also the gate's output width.
            top_k: Forwarded to ``FusedMoE`` as ``top_k``.
            hidden_size: Width of the token representations.
            intermediate_size: Forwarded to ``FusedMoE``.
            params_dtype: Parameter dtype for both the gate and the experts.
            quant_config: Quantization config applied to the experts only.
            tp_size: Forwarded to ``FusedMoE``.
            dp_size: Forwarded to ``FusedMoE``.
            prefix: Dotted module path used to name the sub-layers.
        """
        super().__init__()
        self.hidden_size = hidden_size

        # The gate always runs at half / full precision for now, which is
        # why it is built with quant_config=None.
        self.gate = ReplicatedLinear(
            hidden_size,
            num_experts,
            bias=False,
            params_dtype=params_dtype,
            quant_config=None,
            prefix=f"{prefix}.gate",
        )

        self.experts = FusedMoE(
            num_experts=num_experts,
            top_k=top_k,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            params_dtype=params_dtype,
            reduce_results=True,
            renormalize=True,
            quant_config=quant_config,
            tp_size=tp_size,
            dp_size=dp_size,
            prefix=f"{prefix}.experts",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens through the experts and restore the input shape."""
        # NOTE: the input may be 1D or 2D; remember its shape so the
        # result can be handed back in the same layout.
        input_shape = hidden_states.shape
        flat_tokens = hidden_states.view(-1, self.hidden_size)
        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(flat_tokens)
        expert_output = self.experts(flat_tokens, router_logits)
        return expert_output.view(input_shape)
Pierre Stock's avatar
Pierre Stock committed
115
116
117
118


class MixtralAttention(nn.Module):
    """Self-attention layer for Mixtral with a fused QKV projection.

    Query heads are split evenly across tensor-parallel ranks. KV heads
    are partitioned when there are at least as many as ranks, and
    replicated otherwise.
    """

    def __init__(
        self,
        config: MixtralConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position: int = 4096 * 32,
        rope_theta: float = 10000,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Create the projections, rotary embedding and attention backend.

        Args:
            config: Model config; only its optional ``head_dim`` is read.
            hidden_size: Width of the token representations.
            num_heads: Total number of query heads (before TP sharding).
            num_kv_heads: Total number of key/value heads.
            max_position: Maximum position passed to the rotary embedding.
            rope_theta: Rotary base; truncated to ``int`` before use.
            cache_config: Forwarded to ``Attention``.
            quant_config: Forwarded to the projections and ``Attention``.
            prefix: Dotted module path used to name the sub-layers.
        """
        super().__init__()
        world_size = get_tensor_model_parallel_world_size()

        # Query heads must divide evenly over the tensor-parallel ranks.
        assert num_heads % world_size == 0
        if num_kv_heads < world_size:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert world_size % num_kv_heads == 0
        else:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert num_kv_heads % world_size == 0

        # MixtralConfig has an optional head_dim argument; fall back to the
        # conventional hidden_size / num_heads split when it is absent.
        head_dim = getattr(config, "head_dim", None)
        if head_dim is None:
            head_dim = hidden_size // num_heads

        local_heads = num_heads // world_size
        local_kv_heads = max(1, num_kv_heads // world_size)

        self.hidden_size = hidden_size
        self.total_num_heads = num_heads
        self.num_heads = local_heads
        self.total_num_kv_heads = num_kv_heads
        self.num_kv_heads = local_kv_heads
        self.head_dim = head_dim
        self.q_size = local_heads * head_dim
        self.kv_size = local_kv_heads * head_dim
        self.scaling = head_dim**-0.5
        self.rope_theta = rope_theta

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            head_dim,
            num_heads,
            num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            num_heads * head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        self.rotary_emb = get_rope(
            head_dim,
            rotary_dim=head_dim,
            max_position=max_position,
            base=int(rope_theta),
            is_neox_style=True,
        )
        self.attn = Attention(
            local_heads,
            head_dim,
            self.scaling,
            num_kv_heads=local_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project to Q/K/V, apply rotary embedding, attend, project out."""
        fused_qkv, _ = self.qkv_proj(hidden_states)
        # The fused projection lays out Q, then K, then V on the last axis.
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        query, key, value = fused_qkv.split(split_sizes, dim=-1)
        query, key = self.rotary_emb(positions, query, key)
        context = self.attn(query, key, value)
        projected, _ = self.o_proj(context)
        return projected


class MixtralDecoderLayer(nn.Module):
    """A single Mixtral decoder layer: self-attention followed by an MoE.

    Both sub-layers are preceded by an RMSNorm. The residual stream is
    carried alongside the hidden states and returned to the caller instead
    of being added inside this layer.
    """

    def __init__(
        self,
        config: MixtralConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the attention, MoE and normalization sub-layers.

        Args:
            config: Model config supplying sizes, head counts, expert
                counts and the RMSNorm epsilon.
            cache_config: Forwarded to the attention layer.
            quant_config: Forwarded to the attention and MoE layers.
            prefix: Dotted module path used to name the sub-layers.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        # Requires transformers > 4.32.0
        rope_theta = getattr(config, "rope_theta", 10000)
        self.self_attn = MixtralAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn")
        self.block_sparse_moe = MixtralMoE(
            num_experts=config.num_local_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            quant_config=quant_config,
            prefix=f"{prefix}.block_sparse_moe")
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the layer and return the updated stream and residual.

        Args:
            positions: Token positions, forwarded to the attention layer.
            hidden_states: Input activations for this layer.
            residual: Residual carried over from the previous layer, or
                ``None`` when there is no residual yet.

        Returns:
            A ``(hidden_states, residual)`` pair. The previous annotation
            claimed a single tensor, which did not match the return value.
        """
        # Self Attention
        if residual is None:
            # No residual yet: the raw input becomes the residual.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            # Called with a residual, the norm returns both the normalized
            # activations and the updated residual.
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.block_sparse_moe(hidden_states)
        return hidden_states, residual
Pierre Stock's avatar
Pierre Stock committed
258

259

260
@support_torch_compile
class MixtralModel(nn.Module):
    """Mixtral decoder stack: embeddings, decoder layers and final norm.

    Only the layers in ``[start_layer, end_layer)`` are run by ``forward``,
    and intermediate activations are exchanged through
    ``IntermediateTensors`` when this rank is not the first or last
    pipeline stage.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the model from a ``VllmConfig``.

        Args:
            vllm_config: Supplies the HF model config plus the cache,
                quantization and LoRA configs.
            prefix: Dotted module path used to name the sub-layers.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.quant_config = quant_config
        # Extra vocabulary rows reserved for LoRA adapters (0 without LoRA).
        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: MixtralDecoderLayer(
                config, cache_config, quant_config=quant_config, prefix=prefix
            ),
            prefix=f"{prefix}.layers")

        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

        # Name of the active quantization method, or None when unquantized.
        self.quant_method = None
        if quant_config is not None:
            self.quant_method = quant_config.get_name()

        # Opt-in switch for the post-load weight re-layout done at the end
        # of load_weights; enabled by exporting LLAMA_NN=1.
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of the decoder stack.

        Args:
            input_ids: Token ids; embedded on the first pipeline rank
                unless ``inputs_embeds`` is given.
            positions: Token positions, forwarded to every layer.
            intermediate_tensors: Activations from the previous pipeline
                rank; must not be ``None`` on non-first ranks.
            inputs_embeds: Optional precomputed embeddings that replace
                the embedding lookup on the first rank.

        Returns:
            The normalized hidden states on the last pipeline rank,
            otherwise an ``IntermediateTensors`` holding ``hidden_states``
            and ``residual`` for the next rank.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        for layer in self.layers[self.start_layer:self.end_layer]:
            hidden_states, residual = layer(positions, hidden_states, residual)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return the checkpoint-to-parameter mapping for expert weights."""
        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        return FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="w1",
            ckpt_down_proj_name="w2",
            ckpt_up_proj_name="w3",
            num_experts=self.config.num_local_experts)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Each checkpoint entry is tried, in order, as a KV-cache
        quantization scale, a shard of the fused ``qkv_proj``, an expert
        weight, and finally a plain parameter.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The names of the parameters that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        expert_params_mapping = self.get_expert_mapping()
        for name, loaded_weight in weights:
            if (self.quant_config is not None and
                (scale_name := self.quant_config.get_cache_scale(name))):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
                                 loaded_weight[0])
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue

            # Set when a stacked scale has no destination parameter, so the
            # whole checkpoint entry is skipped.
            skip_weight = False
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if ((name.endswith(".bias") or name.endswith("_bias"))
                        and name not in params_dict):
                    continue
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
                if name.endswith("scale"):
                    # Remapping the name of FP8 kv-scale.
                    remapped_name = maybe_remap_kv_scale_name(
                        name, params_dict)
                    if remapped_name is None:
                        # Previously `name` was overwritten with None and
                        # the loop continued, so the next substring check
                        # raised TypeError. Skip the entry instead.
                        skip_weight = True
                        break
                    name = remapped_name
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    # Skip layers on other devices.
                    if is_pp_missing_parameter(name, self):
                        continue
                    if ((name.endswith(".bias") or name.endswith("_bias"))
                            and name not in params_dict):
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param,
                                  loaded_weight,
                                  name,
                                  shard_id=shard_id,
                                  expert_id=expert_id)
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if ((name.endswith(".bias") or name.endswith("_bias"))
                            and name not in params_dict):
                        continue
                    # Skip layers on other devices.
                    if is_pp_missing_parameter(name, self):
                        continue
                    # Remapping the name of FP8 kv-scale.
                    name = maybe_remap_kv_scale_name(name, params_dict)
                    if name is None:
                        continue

                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)
            if skip_weight:
                continue
            loaded_params.add(name)

        if self.use_llama_nn and self.quant_method is None:
            self._relayout_weights_for_llama_nn(params_dict, loaded_params)

        return loaded_params

    def _relayout_weights_for_llama_nn(
        self,
        params_dict: dict[str, nn.Parameter],
        loaded_params: set[str],
    ) -> None:
        """Re-lay-out selected loaded weights via ``trans_w16_gemm``."""
        if not loaded_params:
            return

        # Parameter-name fragments selecting the weights to transform.
        # NOTE(review): "lm_head.weight" is not a parameter of this module
        # in the visible code, so that entry looks inert here - confirm.
        target_fragments = (
            "block_sparse_moe.gate.weight",
            "self_attn.qkv_proj.weight",
            "self_attn.o_proj.weight",
            "lm_head.weight",
        )

        # Loop-invariant, so set once rather than once per parameter.
        # NOTE(review): LM_NN is read outside this file; its meaning is not
        # visible here.
        os.environ['LM_NN'] = '0'

        for layer_name in loaded_params:
            # Plain substring match. The earlier regex treated the dots in
            # these names as wildcards.
            if not any(frag in layer_name for frag in target_fragments):
                continue

            weight = params_dict[layer_name]
            relaid = torch.zeros_like(weight.data)
            ori_shape = relaid.shape

            # Presumably writes the transformed weight into `relaid` -
            # TODO confirm against vllm._custom_ops.
            ops.trans_w16_gemm(relaid, weight.data, ori_shape[0],
                               ori_shape[1])
            weight.data.copy_(relaid)

            # Expose the result with the leading dims swapped:
            # (shape[1], everything else).
            weight.data = weight.data.reshape(ori_shape[1], -1)
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523


class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """Mixtral decoder stack with a language-model head for inference."""

    fall_back_to_pt_during_load = False

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
    }

    # LoRA specific attributes
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the backbone, LM head and logits processor.

        Args:
            vllm_config: Supplies the HF model config plus the
                quantization and LoRA configs.
            prefix: Dotted module path under which the backbone is named.
        """
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        lora_config = vllm_config.lora_config
        quant_config = vllm_config.quant_config

        self.config = hf_config
        self.lora_config = lora_config
        self.quant_config = quant_config

        self.model = MixtralModel(vllm_config=vllm_config,
                                  prefix=maybe_prefix(prefix, "model"))

        # Vocabulary size before padding, including any LoRA extra tokens.
        unpadded_vocab_size = hf_config.vocab_size
        if lora_config:
            unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.unpadded_vocab_size = unpadded_vocab_size

        # We need bigger padding if using lora for kernel compatibility.
        if lora_config:
            padding_size = lora_config.lora_vocab_padding_size
        else:
            padding_size = DEFAULT_VOCAB_PADDING_SIZE

        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            hf_config.hidden_size,
            org_num_embeddings=hf_config.vocab_size,
            padding_size=padding_size,
            quant_config=quant_config,
        )
        if self.config.tie_word_embeddings:
            # Share one weight tensor between the embedding and the head.
            self.lm_head.weight = self.model.embed_tokens.weight

        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                hf_config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate the embedding lookup to the backbone."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the backbone and return whatever it returns."""
        return self.model(input_ids, positions, intermediate_tensors,
                          inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Turn hidden states into logits through the LM head."""
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors via ``AutoWeightsLoader``."""
        return AutoWeightsLoader(self).load_weights(weights)

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return the backbone's expert weight mapping."""
        return self.model.get_expert_mapping()