mixtral.py 21.2 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

Pierre Stock's avatar
Pierre Stock committed
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Mixtral model."""
王敏's avatar
王敏 committed
25
26
import os
import re
27
from typing import Iterable, Optional, Set, Tuple, Union
Pierre Stock's avatar
Pierre Stock committed
28
29
30

import torch
from torch import nn
31
from transformers import MixtralConfig
Pierre Stock's avatar
Pierre Stock committed
32

33
from vllm.attention import Attention
34
from vllm.compilation.decorators import support_torch_compile
35
from vllm.config import CacheConfig, VllmConfig
36
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
37
from vllm.model_executor.layers.fused_moe import FusedMoE
Pierre Stock's avatar
Pierre Stock committed
38
from vllm.model_executor.layers.layernorm import RMSNorm
39
from vllm.model_executor.layers.linear import (QKVParallelLinear,
Philipp Moritz's avatar
Philipp Moritz committed
40
                                               ReplicatedLinear,
Pierre Stock's avatar
Pierre Stock committed
41
                                               RowParallelLinear)
42
from vllm.model_executor.layers.logits_processor import LogitsProcessor
43
from vllm.model_executor.layers.quantization import QuantizationConfig
44
from vllm.model_executor.layers.rotary_embedding import get_rope
Joe Runde's avatar
Joe Runde committed
45
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
Pierre Stock's avatar
Pierre Stock committed
46
from vllm.model_executor.layers.vocab_parallel_embedding import (
47
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
48
49
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, maybe_remap_kv_scale_name)
Pierre Stock's avatar
Pierre Stock committed
50
from vllm.model_executor.sampling_metadata import SamplingMetadata
51
from vllm.sequence import IntermediateTensors
王敏's avatar
王敏 committed
52
53
from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf
Pierre Stock's avatar
Pierre Stock committed
54

55
from .interfaces import SupportsLoRA, SupportsPP
56
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
57
58
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
59

Pierre Stock's avatar
Pierre Stock committed
60

Philipp Moritz's avatar
Philipp Moritz committed
61
62
63
64
65
66
67
68
class MixtralMoE(nn.Module):
    """A tensor-parallel MoE implementation for Mixtral that shards each expert
    across all ranks.

    Each expert's weights are sharded across all ranks and a fused MoE
    kernel is used for the forward pass, and finally we reduce the outputs
    across ranks.

    Args:
        num_experts: Number of experts; also the output width of the router
            gate.
        top_k: Forwarded to ``FusedMoE`` as ``top_k``.
        hidden_size: Size of the last dimension of the input; inputs are
            flattened to ``(-1, hidden_size)`` in ``forward``.
        intermediate_size: Forwarded to ``FusedMoE``.
        params_dtype: Parameter dtype for both the gate and the experts.
        quant_config: Quantization config for the experts only; the gate is
            always created with ``quant_config=None``.
        tp_size: Forwarded to ``FusedMoE``.
        dp_size: Forwarded to ``FusedMoE``.
        prefix: Module-name prefix; sub-modules are registered under
            ``{prefix}.gate`` and ``{prefix}.experts``.
    """

    def __init__(self,
                 num_experts: int,
                 top_k: int,
                 hidden_size: int,
                 intermediate_size: int,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 tp_size: Optional[int] = None,
                 dp_size: Optional[int] = None,
                 prefix: str = ""):
        super().__init__()
        self.hidden_size = hidden_size

        # Gate always runs at half / full precision for now.
        self.gate = ReplicatedLinear(hidden_size,
                                     num_experts,
                                     bias=False,
                                     params_dtype=params_dtype,
                                     quant_config=None,
                                     prefix=f"{prefix}.gate")

        self.experts = FusedMoE(num_experts=num_experts,
                                top_k=top_k,
                                hidden_size=hidden_size,
                                intermediate_size=intermediate_size,
                                params_dtype=params_dtype,
                                reduce_results=True,
                                renormalize=True,
                                quant_config=quant_config,
                                tp_size=tp_size,
                                dp_size=dp_size,
                                prefix=f"{prefix}.experts")

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens through the experts and restore the input shape.

        Args:
            hidden_states: Input whose last dimension is ``hidden_size``.

        Returns:
            The expert output viewed back to the shape of ``hidden_states``.
        """
        # NOTE: hidden_states can have either 1D or 2D shape.
        orig_shape = hidden_states.shape
        hidden_states = hidden_states.view(-1, self.hidden_size)
        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        final_hidden_states = self.experts(hidden_states, router_logits)
        return final_hidden_states.view(orig_shape)
Pierre Stock's avatar
Pierre Stock committed
112
113
114
115


class MixtralAttention(nn.Module):
    """Self-attention block for Mixtral with a fused QKV projection.

    Query heads are split evenly across tensor-parallel ranks. KV heads are
    split when there are at least as many KV heads as ranks, otherwise each
    rank keeps one (replicated) KV head.
    """

    def __init__(
        self,
        config: MixtralConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position: int = 4096 * 32,
        rope_theta: float = 10000,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the projections, rotary embedding and attention op.

        Args:
            config: Model config; only its optional ``head_dim`` is read here.
            hidden_size: Input/output width of the block.
            num_heads: Total number of query heads (before TP sharding).
            num_kv_heads: Total number of key/value heads (before TP
                sharding).
            max_position: Forwarded to ``get_rope`` as ``max_position``.
            rope_theta: Rotary base; truncated to ``int`` for ``get_rope``.
            cache_config: Forwarded to ``Attention``.
            quant_config: Forwarded to the projections and ``Attention``.
            prefix: Module-name prefix for the sub-modules.
        """
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than or equal to TP size, so we
            # partition the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        # MixtralConfig has an optional head_dim argument. It can be missing
        # or present-but-None; in both cases derive it from the hidden size
        # (a plain getattr default would let an explicit None through).
        head_dim = getattr(config, "head_dim", None)
        if head_dim is None:
            head_dim = self.hidden_size // self.total_num_heads
        self.head_dim = head_dim
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position,
            base=int(self.rope_theta),
            is_neox_style=True,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project to QKV, apply rotary embedding, attend, project out.

        Args:
            positions: Passed to the rotary embedding together with q and k.
            hidden_states: Input to the fused QKV projection.

        Returns:
            Output of ``o_proj`` applied to the attention result.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # The fused projection lays out [q | k | v] along the last dim.
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


class MixtralDecoderLayer(nn.Module):
    """One Mixtral decoder layer: RMSNorm -> attention -> RMSNorm -> MoE.

    The residual stream is threaded through the two RMSNorm calls rather than
    being added explicitly in this class.
    """

    def __init__(
        self,
        config: MixtralConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the attention, MoE and normalization sub-modules.

        Args:
            config: Model config supplying sizes, head counts, expert counts
                and ``rms_norm_eps``.
            cache_config: Forwarded to ``MixtralAttention``.
            quant_config: Forwarded to ``MixtralAttention`` and ``MixtralMoE``.
            prefix: Module-name prefix for the sub-modules.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        # Requires transformers > 4.32.0
        rope_theta = getattr(config, "rope_theta", 10000)
        self.self_attn = MixtralAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn")
        self.block_sparse_moe = MixtralMoE(
            num_experts=config.num_local_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            quant_config=quant_config,
            prefix=f"{prefix}.block_sparse_moe")
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the layer and return the updated hidden states and residual.

        Args:
            positions: Forwarded to the attention block.
            hidden_states: Layer input.
            residual: Residual carried from the previous layer, or ``None``
                for the first layer, in which case ``hidden_states`` itself
                becomes the residual.

        Returns:
            ``(hidden_states, residual)``: the MoE output and the residual
            produced by the post-attention norm.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.block_sparse_moe(hidden_states)
        return hidden_states, residual
Pierre Stock's avatar
Pierre Stock committed
254

255

256
@support_torch_compile
class MixtralModel(nn.Module):
    """Mixtral decoder stack: token embedding, decoder layers, final norm.

    Only the layers in ``[start_layer, end_layer)`` (as returned by
    ``make_layers``) are executed in ``forward``; hidden states and residual
    are exchanged between pipeline ranks via ``IntermediateTensors``.
    """

    # Substrings of parameter names whose weights are rewritten in place
    # after loading when LLAMA_NN=1 and the model is not quantized.
    _LLAMA_NN_WEIGHT_KEYS = (
        "block_sparse_moe.gate.weight",
        "self_attn.qkv_proj.weight",
        "self_attn.o_proj.weight",
        "lm_head.weight",
    )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the embedding, decoder layers and final norm.

        Args:
            vllm_config: Supplies the HF model config plus cache, quantization
                and LoRA configs.
            prefix: Module-name prefix; layers live under ``{prefix}.layers``.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.quant_config = quant_config
        # Extra vocabulary rows reserved for LoRA adapters (0 without LoRA).
        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: MixtralDecoderLayer(
                config, cache_config, quant_config=quant_config, prefix=prefix
            ),
            prefix=f"{prefix}.layers")

        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

        # Name of the quantization method, or None for an unquantized model.
        self.quant_method = None
        if quant_config is not None:
            self.quant_method = quant_config.get_name()

        # Opt-in switch read from the environment once, at construction.
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of the decoder stack.

        Args:
            input_ids: Token ids; used only on the first pipeline rank when
                ``inputs_embeds`` is not given.
            positions: Forwarded to every decoder layer.
            intermediate_tensors: ``hidden_states`` / ``residual`` from the
                previous pipeline rank; required on non-first ranks.
            inputs_embeds: Optional precomputed embeddings that replace the
                embedding lookup on the first rank.

        Returns:
            The normalized hidden states on the last pipeline rank, otherwise
            an ``IntermediateTensors`` carrying ``hidden_states`` and
            ``residual`` for the next rank.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        for layer in self.layers[self.start_layer:self.end_layer]:
            hidden_states, residual = layer(positions, hidden_states, residual)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    @staticmethod
    def _is_unexpected_bias(name: str, params_dict: dict) -> bool:
        """Return True for a bias tensor that has no matching parameter."""
        return (name.endswith((".bias", "_bias"))
                and name not in params_dict)

    def _load_weight(
        self,
        name: str,
        loaded_weight: torch.Tensor,
        params_dict: dict,
        stacked_params_mapping: list,
        expert_params_mapping: list,
    ) -> Optional[str]:
        """Load one checkpoint tensor into its parameter.

        The checkpoint name is tried, in order, against the fused QKV
        mapping, the expert mapping, and finally as a plain parameter name.
        The first mapping whose ``weight_name`` occurs in ``name`` decides
        the outcome; the original ``name`` is never mutated while probing.

        Returns:
            The parameter name the tensor was loaded under, or ``None`` if
            the tensor was skipped (extra bias, parameter owned by another
            pipeline rank, or a scale name that could not be remapped).
        """
        # 1) q_proj / k_proj / v_proj shards of the fused qkv_proj.
        for param_name, weight_name, shard_id in stacked_params_mapping:
            if weight_name not in name:
                continue
            mapped = name.replace(weight_name, param_name)
            # Skip loading extra bias for GPTQ models.
            if self._is_unexpected_bias(mapped, params_dict):
                return None
            # Skip layers on other devices.
            if is_pp_missing_parameter(mapped, self):
                return None
            if mapped.endswith("scale"):
                # Remapping the name of FP8 kv-scale.
                mapped = maybe_remap_kv_scale_name(mapped, params_dict)
                if mapped is None:
                    return None
            param = params_dict[mapped]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            return mapped

        # 2) Per-expert w1 / w2 / w3 weights of the fused MoE.
        for param_name, weight_name, expert_id, shard_id in (
                expert_params_mapping):
            if weight_name not in name:
                continue
            mapped = name.replace(weight_name, param_name)
            # Skip layers on other devices.
            if is_pp_missing_parameter(mapped, self):
                return None
            if self._is_unexpected_bias(mapped, params_dict):
                return None
            param = params_dict[mapped]
            weight_loader = param.weight_loader
            weight_loader(param,
                          loaded_weight,
                          mapped,
                          shard_id=shard_id,
                          expert_id=expert_id)
            return mapped

        # 3) Everything else maps one-to-one onto a parameter.
        # Skip loading extra bias for GPTQ models.
        if self._is_unexpected_bias(name, params_dict):
            return None
        # Skip layers on other devices.
        if is_pp_missing_parameter(name, self):
            return None
        # Remapping the name of FP8 kv-scale.
        mapped = maybe_remap_kv_scale_name(name, params_dict)
        if mapped is None:
            return None
        param = params_dict[mapped]
        weight_loader = getattr(param, "weight_loader", default_weight_loader)
        weight_loader(param, loaded_weight)
        return mapped

    def _apply_llama_nn_layout(self, loaded_params: Set[str],
                               params_dict: dict) -> None:
        """Rewrite selected dense weights in place for the LLAMA_NN path.

        Runs only when ``LLAMA_NN=1`` was set at construction and the model
        is unquantized. For every loaded parameter whose name contains one of
        ``_LLAMA_NN_WEIGHT_KEYS``, the weight is passed through
        ``ops.trans_w16_gemm`` and then reshaped to ``(shape[1], -1)``.
        """
        if not (self.use_llama_nn and self.quant_method is None):
            return
        if not loaded_params:
            return

        # NOTE(review): consumer of LM_NN is not visible in this file;
        # presumably it disables the same treatment for the LM head.
        os.environ['LM_NN'] = '0'

        for layername in loaded_params:
            # Plain substring match; the keys contain literal dots.
            if not any(key in layername
                       for key in self._LLAMA_NN_WEIGHT_KEYS):
                continue
            weight = params_dict[layername]
            _weight = torch.zeros_like(weight.data)
            ori_shape = _weight.shape
            # NOTE(review): trans_w16_gemm is an opaque custom op; it appears
            # to write a re-laid-out copy of weight.data into _weight.
            ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0],
                               _weight.shape[1])
            weight.data.copy_(_weight)
            weight.data = weight.data.reshape(ori_shape[1], -1)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors into this module's parameters.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The set of parameter names that were actually loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="w1",
            ckpt_down_proj_name="w2",
            ckpt_up_proj_name="w3",
            num_experts=self.config.num_local_experts)

        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            if (self.quant_config is not None and
                (scale_name := self.quant_config.get_cache_scale(name))):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                # Scales are loaded as scalars; take the first element of a
                # non-scalar tensor.
                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
                                 loaded_weight[0])
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue

            loaded_name = self._load_weight(name, loaded_weight, params_dict,
                                            stacked_params_mapping,
                                            expert_params_mapping)
            if loaded_name is not None:
                loaded_params.add(loaded_name)

        self._apply_llama_nn_layout(loaded_params, params_dict)
        return loaded_params
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530


class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """Mixtral decoder stack with an LM head, logits processor and sampler."""

    fall_back_to_pt_during_load = False

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
    }

    # LoRA specific attributes
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Create the decoder, LM head, logits processor and sampler.

        Args:
            vllm_config: Supplies the HF model config plus quantization and
                LoRA configs.
            prefix: Module-name prefix; the decoder is placed under
                ``maybe_prefix(prefix, "model")``.
        """
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        lora_cfg = vllm_config.lora_config
        quant_cfg = vllm_config.quant_config

        self.config = hf_config
        self.lora_config = lora_cfg
        self.quant_config = quant_cfg

        self.model = MixtralModel(vllm_config=vllm_config,
                                  prefix=maybe_prefix(prefix, "model"))

        # With LoRA the head grows by the adapter vocabulary, and a bigger
        # padding is needed for kernel compatibility.
        self.unpadded_vocab_size = hf_config.vocab_size
        if lora_cfg:
            self.unpadded_vocab_size += lora_cfg.lora_extra_vocab_size
            head_padding = lora_cfg.lora_vocab_padding_size
        else:
            head_padding = DEFAULT_VOCAB_PADDING_SIZE

        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            hf_config.hidden_size,
            org_num_embeddings=hf_config.vocab_size,
            padding_size=head_padding,
            quant_config=quant_cfg,
        )
        if hf_config.tie_word_embeddings:
            # Share the embedding matrix with the output projection.
            self.lm_head.weight = self.model.embed_tokens.weight

        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                hf_config.vocab_size)
        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate the embedding lookup to the wrapped decoder."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the wrapped decoder and return whatever it returns."""
        return self.model(input_ids, positions, intermediate_tensors,
                          inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Turn hidden states into logits through the LM head."""
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits`` with the configured sampler."""
        return self.sampler(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load weights via ``AutoWeightsLoader``, skipping rotary inv_freq.

        Returns:
            The set of names reported as loaded by the loader.
        """
        auto_loader = AutoWeightsLoader(
            self, skip_prefixes=["rotary_emb.inv_freq"])
        return auto_loader.load_weights(weights)